def train(args):
    """Train a JaCCGEmbeddingTagger with Chainer's Trainer loop.

    Builds the model and optionally restores weights from ``args.initmodel``
    and/or loads pretrained word embeddings from ``args.pretrained``. It then
    trains with AdaGrad on ``args.train`` and periodically evaluates on
    ``args.val``, snapshots and logs.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``model`` (passed to the model and dataset constructors and used
            as the Trainer output directory), ``word_emb_size``,
            ``char_emb_size``, ``initmodel`` (optional ``.npz`` checkpoint),
            ``pretrained`` (optional embeddings source), ``train`` and
            ``val`` (dataset sources), ``batchsize`` and ``epoch``.

    Returns:
        None. Progress is printed; snapshots and logs are written by the
        Trainer extensions into the ``args.model`` output directory.
    """
    # Evaluate and snapshot every 1000 updates; log/print every 200 updates.
    val_interval = 1000, 'iteration'
    log_interval = 200, 'iteration'

    model = JaCCGEmbeddingTagger(args.model,
                                 args.word_emb_size, args.char_emb_size)
    if args.initmodel:
        print('Load model from', args.initmodel)
        chainer.serializers.load_npz(args.initmodel, model)
    # NOTE(review): pretrained embeddings are loaded *after* the checkpoint,
    # so when both options are given they may overwrite restored embedding
    # weights, depending on what load_pretrained_embeddings does. Confirm
    # this order is intended.
    if args.pretrained:
        print('Load pretrained word embeddings from', args.pretrained)
        model.load_pretrained_embeddings(args.pretrained)

    # Named train_data (not ``train``) so it does not shadow this function.
    train_data = JaCCGTaggerDataset(args.model, args.train)
    train_iter = chainer.iterators.SerialIterator(train_data, args.batchsize)
    val_data = JaCCGTaggerDataset(args.model, args.val)
    # One ordered, non-repeating pass over the validation set per evaluation.
    val_iter = chainer.iterators.SerialIterator(
        val_data, args.batchsize, repeat=False, shuffle=False)

    optimizer = chainer.optimizers.AdaGrad()
    optimizer.setup(model)
    # Optional regularization, currently disabled:
    # optimizer.add_hook(WeightDecay(1e-8))

    def converter(batch, device):
        """Concatenate a minibatch, padding only the second field with -1."""
        # Examples are expected to be 4-tuples; only field 1 is padded.
        # Presumably it holds tag labels and -1 is the loss's ignore label.
        # TODO confirm against JaCCGTaggerDataset / the model's loss.
        return convert.concat_examples(
            batch, device, padding=(None, -1, None, None))

    updater = training.StandardUpdater(
        train_iter, optimizer, converter=converter)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), args.model)

    # Separate evaluation copy so the flag below does not touch the trained
    # model object. Presumably the copy shares parameters with ``model``
    # (Chainer's Link.copy is shallow), so it always sees current weights.
    eval_model = model.copy()
    # Presumably switches the tagger to inference behavior (e.g. no dropout);
    # depends on how JaCCGEmbeddingTagger reads ``self.train``. Confirm.
    eval_model.train = False
    trainer.extend(extensions.Evaluator(val_iter, eval_model, converter),
                   trigger=val_interval)

    trainer.extend(extensions.dump_graph('main/loss'))
    # Full trainer snapshot plus a standalone model snapshot at each
    # evaluation point.
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(extensions.snapshot_object(
        model, 'model_iter_{.updater.iteration}'), trigger=val_interval)

    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy',
    ]), trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.run()
评论列表
文章目录