# Standard-library and Chainer imports needed by train(); the exact import
# paths for timer/logger are assumptions. Project-local helpers used below
# (Vocabulary, create_from_dir, create_from_annotated_dir, get_loss_func,
# get_model, get_data, WindowIterator, DirWindowIterator, convert,
# create_model, logger) are assumed to be defined elsewhere in this module.
import os
from timeit import default_timer as timer

import chainer
from chainer import cuda, training
from chainer.training import extensions


def train(args):
    time_start = timer()
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        cuda.check_cuda_available()
    # Build the vocabulary from the corpus unless a prebuilt one is supplied.
    if args.path_vocab == '':
        vocab = create_from_dir(args.path_corpus)
    else:
        vocab = Vocabulary()
        vocab.load(args.path_vocab)
        logger.info("loaded vocabulary")
    if args.context_representation != 'word':
        # deps/ner context representations need a separate context vocabulary
        # for the negative-sampling (NS) or hierarchical-softmax (HSM) loss.
        vocab_context = create_from_annotated_dir(args.path_corpus, representation=args.context_representation)
    else:
        vocab_context = vocab

    loss_func = get_loss_func(args, vocab_context)
    model = get_model(args, loss_func, vocab)
    if args.gpu >= 0:
        model.to_gpu()
        logger.debug("model sent to gpu")

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # A single-file corpus is split into train/validation sets; a directory
    # corpus is streamed with no validation split.
    if os.path.isfile(args.path_corpus):
        train, val = get_data(args.path_corpus, vocab)
        if args.test:
            train = train[:100]
            val = val[:100]
        train_iter = WindowIterator(train, args.window, args.batchsize)
        val_iter = WindowIterator(val, args.window, args.batchsize, repeat=False)
    else:
        train_iter = DirWindowIterator(path=args.path_corpus, vocab=vocab, window_size=args.window, batch_size=args.batchsize)
    updater = training.StandardUpdater(train_iter, optimizer, converter=convert, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.path_out)
    # Validation loss can only be reported when a validation split exists.
    if os.path.isfile(args.path_corpus):
        trainer.extend(extensions.Evaluator(val_iter, model, converter=convert, device=args.gpu))
    trainer.extend(extensions.LogReport())
    if os.path.isfile(args.path_corpus):
        trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss', 'elapsed_time']))
    else:
        trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'elapsed_time']))
    # trainer.extend(extensions.ProgressBar())
    trainer.run()
    model = create_model(args, model, vocab)
    time_end = timer()
    model.metadata["execution_time"] = time_end - time_start
    return model
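

# A minimal sketch of how train() might be invoked from the command line.
# The flag names mirror the attributes read above, but the argparse wiring
# and every default value are illustrative assumptions, not the project's
# actual CLI.
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="train word embeddings")
    parser.add_argument("--path_corpus", required=True)
    parser.add_argument("--path_vocab", default="")
    parser.add_argument("--path_out", default="result")
    parser.add_argument("--context_representation", default="word")
    parser.add_argument("--gpu", type=int, default=-1)   # negative = CPU
    parser.add_argument("--window", type=int, default=5)
    parser.add_argument("--batchsize", type=int, default=1000)
    parser.add_argument("--epoch", type=int, default=5)
    parser.add_argument("--test", action="store_true")
    return parser.parse_args()

if __name__ == "__main__":
    model = train(parse_args())
    print("training took %.1f s" % model.metadata["execution_time"])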