def _make_split_buckets(dataset, split_name, buckets_slice):
    """Bucket one data split and print the number of items per bucket.

    Args:
        dataset: Data for a single split; only ``len()`` is applied to it
            here before it is handed to ``make_buckets``.
        split_name: Label inserted into the printed header
            (e.g. ``"train"``).
        buckets_slice: When not ``None``, only the first
            ``buckets_slice + 1`` buckets are kept.

    Returns:
        The (possibly truncated) buckets returned by ``make_buckets``, or
        ``None`` when the split is empty.
    """
    # Explicit length check rather than truthiness: the container type is
    # not visible here and may not define a usable ``__bool__``.
    if len(dataset) == 0:
        return None

    printb("buckets #data ({})".format(split_name))
    buckets = make_buckets(dataset)
    if buckets_slice is not None:
        buckets = buckets[:buckets_slice + 1]
    # zip() stops at the shorter sequence, so a truncated bucket list simply
    # reports fewer sizes.
    for size, data in zip(bucket_sizes, buckets):
        print("{} {}".format(size, len(data)))
    return buckets


def main(args):
    """Load a model and dump translations for every non-empty source split.

    The vocabulary and the model are both loaded from ``args.model_dir``.
    Only source-side files are read (every target path is passed as
    ``None``), each non-empty split is bucketed, and
    ``dump_source_translation`` is called once per bucketed split in the
    order train, dev, test.

    Args:
        args: Object exposing the attributes ``model_dir``, ``source_train``,
            ``source_dev``, ``source_test``, ``buckets_slice``,
            ``gpu_device``, ``beam_width`` and ``alpha``.

    Raises:
        RuntimeError: If ``load_model`` returns ``None``.
    """
    vocab, vocab_inv = load_vocab(args.model_dir)
    vocab_source, vocab_target = vocab
    vocab_inv_source, vocab_inv_target = vocab_inv

    # The target side is not used below; only the source splits are kept.
    source_dataset, _target_dataset = read_data(
        vocab_source, vocab_target,
        args.source_train, None,
        args.source_dev, None,
        args.source_test, None,
        reverse_source=True)

    # Unpack explicitly so that an unexpected number of splits still fails
    # loudly instead of being silently truncated.
    dataset_train, dataset_dev, dataset_test = source_dataset
    splits = (
        ("train", dataset_train),
        ("dev", dataset_dev),
        ("test", dataset_test),
    )

    printb("data #")
    for split_name, dataset in splits:
        if len(dataset) > 0:
            print("{} {}".format(split_name, len(dataset)))

    # split into buckets
    # All splits are bucketed (and reported) before the model is loaded,
    # so the console output order is: data counts, bucket tables, then dumps.
    buckets_per_split = [
        _make_split_buckets(dataset, split_name, args.buckets_slice)
        for split_name, dataset in splits
    ]

    # init
    model = load_model(args.model_dir)
    # Explicit check instead of ``assert`` so it still runs under ``-O``.
    if model is None:
        raise RuntimeError(
            "failed to load a model from {}".format(args.model_dir))
    if args.gpu_device >= 0:
        cuda.get_device(args.gpu_device).use()
        model.to_gpu()

    for buckets in buckets_per_split:
        if buckets is not None:
            dump_source_translation(
                model, buckets, vocab_inv_source, vocab_inv_target,
                beam_width=args.beam_width, normalization_alpha=args.alpha)
# (Web-page residue, not part of the program: "Comment list" / "Article table of contents".)