japanese_tagger.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:depccg 作者: masashi-y 项目源码 文件源码
def train(args):
    """Set up the tagger, data iterators and a Chainer trainer, then train.

    The following attributes are read from ``args``: ``model``,
    ``word_emb_size``, ``char_emb_size``, ``initmodel``, ``pretrained``,
    ``train``, ``val``, ``batchsize`` and ``epoch``.  ``args.model`` is
    handed to the tagger, to both datasets, and to the trainer as its
    third positional argument.

    Nothing is returned; the function finishes when ``trainer.run()``
    returns.
    """
    tagger = JaCCGEmbeddingTagger(
        args.model, args.word_emb_size, args.char_emb_size)

    # Optionally start from previously saved parameters.
    if args.initmodel:
        print('Load model from', args.initmodel)
        chainer.serializers.load_npz(args.initmodel, tagger)

    # Optionally load pretrained word embeddings into the tagger.
    if args.pretrained:
        print('Load pretrained word embeddings from', args.pretrained)
        tagger.load_pretrained_embeddings(args.pretrained)

    # Training iterator uses the SerialIterator defaults; the validation
    # iterator makes a single, unshuffled pass.
    train_data = JaCCGTaggerDataset(args.model, args.train)
    train_iter = chainer.iterators.SerialIterator(train_data, args.batchsize)
    val_data = JaCCGTaggerDataset(args.model, args.val)
    val_iter = chainer.iterators.SerialIterator(
        val_data, args.batchsize, repeat=False, shuffle=False)

    adagrad = chainer.optimizers.AdaGrad()
    adagrad.setup(tagger)
    # adagrad.add_hook(WeightDecay(1e-8))

    def batch_converter(x, dev):
        # The tuple is forwarded as concat_examples' third argument --
        # presumably per-field padding values (only the second field gets
        # -1); confirm against the dataset's example layout.
        return convert.concat_examples(x, dev, (None, -1, None, None))

    updater = training.StandardUpdater(
        train_iter, adagrad, converter=batch_converter)
    stop_trigger = (args.epoch, 'epoch')
    trainer = training.Trainer(updater, stop_trigger, args.model)

    val_interval = (1000, 'iteration')
    log_interval = (200, 'iteration')

    # Evaluation runs on a copy of the tagger whose ``train`` attribute is
    # switched off; the original ``tagger`` is what gets snapshotted below.
    eval_tagger = tagger.copy()
    eval_tagger.train = False

    # Register the extensions in the same order as before.
    evaluator = extensions.Evaluator(val_iter, eval_tagger, batch_converter)
    trainer.extend(evaluator, trigger=val_interval)

    graph_dumper = extensions.dump_graph('main/loss')
    trainer.extend(graph_dumper)

    trainer_snapshot = extensions.snapshot()
    trainer.extend(trainer_snapshot, trigger=val_interval)

    model_snapshot = extensions.snapshot_object(
        tagger, 'model_iter_{.updater.iteration}')
    trainer.extend(model_snapshot, trigger=val_interval)

    log_report = extensions.LogReport(trigger=log_interval)
    trainer.extend(log_report)

    report_columns = [
        'epoch', 'iteration', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy',
    ]
    print_report = extensions.PrintReport(report_columns)
    trainer.extend(print_report, trigger=log_interval)

    progress_bar = extensions.ProgressBar(update_interval=10)
    trainer.extend(progress_bar)

    trainer.run()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号