推論(翻訳)

OpenNMT-pyの学習処理のソースコードリーディングを行う。

以前、以下のコマンドでGlobalAttentionを使ったRNNモデルでの学習を行った。

cd ~/OpenNMT-py # 翻訳を実行 onmt_translate -model demo-model_step_10000.pt -src data/src-test.txt -output pred.txt -replace_unk -verbose

OpenNMT-py/onmt/bin/translate.py

onmt_translateコマンドは、OpenNMT-py/onmt/bin/translate.pyを呼び出している。

def _get_parser(): parser = ArgumentParser(description='translate.py') opts.config_opts(parser) opts.translate_opts(parser) return parser def main(): parser = _get_parser() opt = parser.parse_args() translate(opt) if __name__ == "__main__": main()

学習処理と同様に、コマンドパラメータを取得して、メイン処理のtranslate()を呼び出している。

translate処理について、コード中にコメントを入れた。

def translate(opt): ArgumentParser.validate_translate_opts(opt) logger = init_logger(opt.log_file) # translatorの作成 translator = build_translator(opt, logger=logger, report_score=True) src_shards = split_corpus(opt.src, opt.shard_size) tgt_shards = split_corpus(opt.tgt, opt.shard_size) shard_pairs = zip(src_shards, tgt_shards) for i, (src_shard, tgt_shard) in enumerate(shard_pairs): logger.info("Translating shard %d." % i) # 翻訳処理の実行 translator.translate( src=src_shard, tgt=tgt_shard, src_dir=opt.src_dir, batch_size=opt.batch_size, batch_type=opt.batch_type, attn_debug=opt.attn_debug, align_debug=opt.align_debug )

サブモジュール