推論(翻訳)
OpenNMT-pyの学習処理のソースコードリーディングを行う。
以前、以下のコマンドでGlobalAttentionを使ったRNNモデルでの学習を行った。
cd ~/OpenNMT-py # 翻訳を実行 onmt_translate -model demo-model_step_10000.pt -src data/src-test.txt -output pred.txt -replace_unk -verbose
OpenNMT-py/onmt/bin/translate.py
onmt_translateコマンドは、OpenNMT-py/onmt/bin/translate.pyを呼び出している。
def _get_parser(): parser = ArgumentParser(description='translate.py') opts.config_opts(parser) opts.translate_opts(parser) return parser def main(): parser = _get_parser() opt = parser.parse_args() translate(opt) if __name__ == "__main__": main()
学習処理と同様に、コマンドパラメータを取得して、メイン処理のtranslate()を呼び出している。
translate処理について、コード中にコメントを入れた。
def translate(opt): ArgumentParser.validate_translate_opts(opt) logger = init_logger(opt.log_file) # translatorの作成 translator = build_translator(opt, logger=logger, report_score=True) src_shards = split_corpus(opt.src, opt.shard_size) tgt_shards = split_corpus(opt.tgt, opt.shard_size) shard_pairs = zip(src_shards, tgt_shards) for i, (src_shard, tgt_shard) in enumerate(shard_pairs): logger.info("Translating shard %d." % i) # 翻訳処理の実行 translator.translate( src=src_shard, tgt=tgt_shard, src_dir=opt.src_dir, batch_size=opt.batch_size, batch_type=opt.batch_type, attn_debug=opt.attn_debug, align_debug=opt.align_debug )