translate.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:chainer-qrnn 作者: musyoku 项目源码 文件源码
def _make_sliced_buckets(dataset, split_name, args):
    """Bucket one data split, print a per-bucket summary, and return the buckets.

    Args:
        dataset: The data for one split, as unpacked from ``read_data``'s result.
        split_name: Label inserted into the summary header (e.g. ``"train"``).
        args: Parsed arguments; ``args.buckets_slice`` is read only when the
            split is non-empty. If it is not ``None``, only buckets with index
            ``0..buckets_slice`` (inclusive) are kept.

    Returns:
        The (possibly truncated) result of ``make_buckets(dataset)``, or
        ``None`` when ``dataset`` is empty.
    """
    if len(dataset) == 0:
        return None

    printb("buckets     #data   ({})".format(split_name))
    buckets = make_buckets(dataset)
    if args.buckets_slice is not None:
        # The slice bound is inclusive, hence the +1.
        buckets = buckets[:args.buckets_slice + 1]
    # zip() stops at the shorter sequence, so a truncated bucket list simply
    # prints fewer rows.
    for size, data in zip(bucket_sizes, buckets):
        print("{}   {}".format(size, len(data)))
    return buckets


def main(args):
    """Run the saved model over every non-empty source split.

    Steps, in order:
      1. Load the vocabularies from ``args.model_dir``.
      2. Read the source-side train/dev/test files (no target files are
         passed) with ``reverse_source=True`` and print their sizes.
      3. Split each non-empty split into buckets, optionally truncated by
         ``args.buckets_slice``, printing per-bucket counts.
      4. Load the model from ``args.model_dir`` and move it to the GPU when
         ``args.gpu_device >= 0``.
      5. Call ``dump_source_translation`` for each non-empty split using
         ``args.beam_width`` and ``args.alpha``.

    Args:
        args: Parsed command-line arguments. Attributes read here:
            ``model_dir``, ``source_train``, ``source_dev``, ``source_test``,
            ``buckets_slice``, ``gpu_device``, ``beam_width``, ``alpha``.
    """
    vocab, vocab_inv = load_vocab(args.model_dir)
    vocab_source, vocab_target = vocab
    vocab_inv_source, vocab_inv_target = vocab_inv

    # Every target-file argument is None, so the target half of the result
    # is not used by this script.
    source_dataset, _ = read_data(
        vocab_source, vocab_target,
        args.source_train, None,
        args.source_dev, None,
        args.source_test, None,
        reverse_source=True)

    source_dataset_train, source_dataset_dev, source_dataset_test = source_dataset

    printb("data    #")
    # Format strings are kept verbatim per split because their column padding differs.
    for line_format, dataset in (
            ("train    {}", source_dataset_train),
            ("dev  {}", source_dataset_dev),
            ("test {}", source_dataset_test)):
        if len(dataset) > 0:
            print(line_format.format(len(dataset)))

    # split into buckets (an empty split yields None and is skipped later)
    buckets_per_split = [
        _make_sliced_buckets(dataset, split_name, args)
        for split_name, dataset in (
            ("train", source_dataset_train),
            ("dev", source_dataset_dev),
            ("test", source_dataset_test))
    ]

    # init
    model = load_model(args.model_dir)
    assert model is not None, "no model could be loaded from {}".format(args.model_dir)
    if args.gpu_device >= 0:
        cuda.get_device(args.gpu_device).use()
        model.to_gpu()

    # Translate in the fixed order train -> dev -> test.
    for buckets in buckets_per_split:
        if buckets is not None:
            dump_source_translation(
                model, buckets, vocab_inv_source, vocab_inv_target,
                beam_width=args.beam_width, normalization_alpha=args.alpha)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号