def run_training(args):
    """Run the training loop over the dataset, checkpointing each batch.

    Expects *args* to carry at least:
        directory : str        -- output directory for metadata/checkpoints
        source    : str        -- dataset source handed to ``dataset.load``
        epoch     : int | None -- if given, resume from checkpoint ``epoch``
                                  and continue at sentence ``epoch + 1``

    Side effects: writes the metadata file and periodic model/optimizer
    checkpoints under *args.directory*, prompts on stdin before overwriting
    existing metadata, and prints progress to stdout.
    """
    out_dir = pathlib.Path(args.directory)
    sentences = dataset.load(args.source)

    if args.epoch is not None:
        # Resume: reload saved state and skip sentences already processed.
        start = args.epoch + 1
        storage = load(out_dir, args.epoch)
        sentences = itertools.islice(sentences, start, None)
    else:
        # Fresh run: build new state and persist the metadata up front.
        start = 0
        storage = init(args)
        if (out_dir / meta_name).exists():
            # Guard against silently clobbering a previous run's metadata.
            if input('Overwrite? [y/N]: ').strip().lower() != 'y':
                # SystemExit instead of exit(): exit() is injected by the
                # `site` module and is not guaranteed to be available.
                raise SystemExit(1)
        with (out_dir / meta_name).open('wb') as f:
            # NOTE(review): saving a Python object goes through np.save's
            # pickle path -- loading it back needs allow_pickle=True.
            np.save(f, [storage])

    batchsize = 5000
    for i, sentence in enumerate(sentences, start):
        if i % batchsize == 0:
            # Batch boundary: finish the progress line and checkpoint both
            # the model and the optimizer state.
            print()
            serializers.save_npz(
                str(out_dir / model_name(i)),
                storage.model
            )
            serializers.save_npz(
                str(out_dir / optimizer_name(i)),
                storage.optimizer
            )
        else:
            # In-place progress indicator for the current batch (no newline).
            print(
                util.progress(
                    'batch {}'.format(i // batchsize),
                    (i % batchsize) / batchsize, 100),
                end=''
            )
        train(storage.model,
              storage.optimizer,
              generate_data(sentence),
              generate_label(sentence),
              generate_attr(
                  sentence,
                  storage.mappings
              )
              )
# (removed stray scraped page-footer text that was not part of the code)