run.py source code

python

Project: pytorch-nlp    Author: endymecy
import torch

# Word2Vec is defined elsewhere in the pytorch-nlp project (not shown in this snippet).


def main(args):
    # Limit PyTorch to 5 CPU threads for intra-op parallelism.
    torch.set_num_threads(5)
    if args.method not in ('cbow', 'skip_gram'):
        raise ValueError(f"method must be 'cbow' or 'skip_gram', got {args.method!r}")
    use_cbow = args.method == 'cbow'
    # Both models share one constructor; only the cbow/skip_gram flags differ.
    word2vec = Word2Vec(input_file_name=args.input_file_name,
                        output_file_name=args.output_file_name,
                        emb_dimension=args.emb_dimension,
                        batch_size=args.batch_size,
                        # window_size is used by the Skip-Gram model
                        window_size=args.window_size,
                        iteration=args.iteration,
                        initial_lr=args.initial_lr,
                        min_count=args.min_count,
                        using_hs=args.using_hs,
                        using_neg=args.using_neg,
                        # context_size is used by the CBOW model
                        context_size=args.context_size,
                        hidden_size=args.hidden_size,
                        cbow=use_cbow,
                        skip_gram=not use_cbow)
    if use_cbow:
        word2vec.cbow_train()
    else:
        word2vec.skip_gram_train()
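`main` expects an `args` object that exposes every attribute it reads. The snippet does not show how run.py builds that object. The following is a minimal, hypothetical sketch using the standard-library `argparse`. The flag names simply mirror the attributes `main` accesses. The defaults are illustrative and not taken from pytorch-nlp. Reading `using_hs` and `using_neg` as on/off switches for hierarchical softmax and negative sampling is an assumption.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical CLI: flag names mirror the attributes main() reads from args;
    # every default below is illustrative, not pytorch-nlp's actual value.
    parser = argparse.ArgumentParser(description="Train word2vec (CBOW or Skip-Gram)")
    parser.add_argument("--method", choices=["cbow", "skip_gram"], default="skip_gram")
    parser.add_argument("--input_file_name", required=True)
    parser.add_argument("--output_file_name", required=True)
    parser.add_argument("--emb_dimension", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=50)
    parser.add_argument("--window_size", type=int, default=5)   # read by Skip-Gram
    parser.add_argument("--context_size", type=int, default=2)  # read by CBOW
    parser.add_argument("--hidden_size", type=int, default=64)
    parser.add_argument("--iteration", type=int, default=1)
    parser.add_argument("--initial_lr", type=float, default=0.025)
    parser.add_argument("--min_count", type=int, default=5)
    # Assumed boolean switches (presumably hierarchical softmax / negative sampling).
    parser.add_argument("--using_hs", action="store_true")
    parser.add_argument("--using_neg", action="store_true")
    return parser
```

In a script, `main(build_parser().parse_args())` would pass the parsed values to `main`. Using `choices` for `--method` rejects typos at parse time, before any model is built.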
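The comments in `main` say that `window_size` belongs to Skip-Gram and `context_size` to CBOW. The two models turn the same text into opposite prediction tasks. Skip-Gram uses each centre word to predict each neighbour. CBOW uses the set of neighbours to predict the centre word. The sketch below is a generic illustration of that difference, not pytorch-nlp's data pipeline. It assumes that both sizes count words on each side of the centre word. The function names are hypothetical.

```python
from typing import List, Tuple


def skip_gram_pairs(tokens: List[str], window_size: int) -> List[Tuple[str, str]]:
    """Hypothetical helper: (center, context) pairs, one per neighbour."""
    pairs = []
    for i, center in enumerate(tokens):
        lo = max(0, i - window_size)
        hi = min(len(tokens), i + window_size + 1)
        for j in range(lo, hi):
            if j != i:
                pairs.append((center, tokens[j]))
    return pairs


def cbow_examples(tokens: List[str], context_size: int) -> List[Tuple[List[str], str]]:
    """Hypothetical helper: (context_words, center) examples, one per position."""
    examples = []
    for i, center in enumerate(tokens):
        lo = max(0, i - context_size)
        hi = min(len(tokens), i + context_size + 1)
        context = [tokens[j] for j in range(lo, hi) if j != i]
        if context:
            examples.append((context, center))
    return examples
```

For `["the", "quick", "brown", "fox"]` with a size of 1, Skip-Gram yields six (centre, neighbour) pairs. CBOW yields four examples, one per word, each grouping that word's neighbours as the input.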