train_word2vec.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:vsmlib 作者: undertherain 项目源码 文件源码
def train(args):
    """Train an embedding model with Chainer and return the wrapped result.

    The corpus at ``args.path_corpus`` may be either a single file or a
    directory. A single file is split into train/validation data by
    ``get_data`` and a validation loss is reported; anything else is streamed
    through ``DirWindowIterator`` and trained without validation.

    Args:
        args: Namespace-like object. Attributes read here: ``gpu``,
            ``path_vocab``, ``path_corpus``, ``context_representation``,
            ``test``, ``window``, ``batchsize``, ``epoch`` and ``path_out``.
            ``args`` is also forwarded as-is to ``get_loss_func``,
            ``get_model`` and ``create_model``.

    Returns:
        The object produced by ``create_model``, with
        ``metadata["execution_time"]`` set to the difference of the two
        ``timer()`` readings taken at the start and end of this function.
    """
    time_start = timer()
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        cuda.check_cuda_available()

    # An empty vocabulary path means "build the vocabulary from the corpus".
    if args.path_vocab == '':
        vocab = create_from_dir(args.path_corpus)
    else:
        vocab = Vocabulary()
        vocab.load(args.path_vocab)
        logger.info("loaded vocabulary")

    # For non-word context representations (e.g. deps or ner) the loss
    # function (NS or HSM) needs its own context vocabulary.
    if args.context_representation != 'word':
        vocab_context = create_from_annotated_dir(
            args.path_corpus, representation=args.context_representation)
    else:
        vocab_context = vocab

    loss_func = get_loss_func(args, vocab_context)
    model = get_model(args, loss_func, vocab)

    if args.gpu >= 0:
        model.to_gpu()
        logger.debug("model sent to gpu")

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Query the filesystem once so the iterator setup, the evaluator and the
    # report columns can never disagree about whether validation data exists.
    corpus_is_file = os.path.isfile(args.path_corpus)
    val_iter = None
    if corpus_is_file:
        # Named *_data so the locals do not shadow this function's own name.
        train_data, val_data = get_data(args.path_corpus, vocab)
        if args.test:
            # Test mode: keep only the first 100 items of each split.
            train_data = train_data[:100]
            val_data = val_data[:100]
        train_iter = WindowIterator(train_data, args.window, args.batchsize)
        val_iter = WindowIterator(
            val_data, args.window, args.batchsize, repeat=False)
    else:
        train_iter = DirWindowIterator(
            path=args.path_corpus,
            vocab=vocab,
            window_size=args.window,
            batch_size=args.batchsize,
        )

    updater = training.StandardUpdater(
        train_iter, optimizer, converter=convert, device=args.gpu)
    trainer = training.Trainer(
        updater, (args.epoch, 'epoch'), out=args.path_out)

    # Extensions are registered in the same order as before:
    # Evaluator (only with validation data), LogReport, PrintReport.
    report_keys = ['epoch', 'main/loss']
    if corpus_is_file:
        trainer.extend(extensions.Evaluator(
            val_iter, model, converter=convert, device=args.gpu))
        report_keys.append('validation/main/loss')
    report_keys.append('elapsed_time')

    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(report_keys))
    # trainer.extend(extensions.ProgressBar())
    trainer.run()

    # Rebind `model` to the wrapped result and record the elapsed time on it.
    model = create_model(args, model, vocab)
    time_end = timer()
    model.metadata["execution_time"] = time_end - time_start
    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号