def main(args):
    """Train a Word2Vec model using the algorithm named by ``args.method``.

    Builds one ``Word2Vec`` instance from the options on ``args`` and then
    runs the training routine matching the selected method
    (``cbow_train`` for ``'cbow'``, ``skip_gram_train`` for ``'skip_gram'``).

    Args:
        args: Options object (e.g. an ``argparse.Namespace``). Reads
            ``method`` plus ``input_file_name``, ``output_file_name``,
            ``emb_dimension``, ``batch_size``, ``window_size``,
            ``iteration``, ``initial_lr``, ``min_count``, ``using_hs``,
            ``using_neg``, ``context_size`` and ``hidden_size``, all of
            which are forwarded unchanged to ``Word2Vec``.

    Raises:
        ValueError: If ``args.method`` is neither ``'cbow'`` nor
            ``'skip_gram'``. Previously such a value made the function
            return silently without training anything.
    """
    method = args.method
    # Fail fast on an unknown method, before touching global torch state,
    # instead of silently doing nothing.
    if method not in ('cbow', 'skip_gram'):
        raise ValueError(
            "unsupported method {!r}; expected 'cbow' or 'skip_gram'".format(method)
        )

    # Hard-coded thread count for torch's CPU parallelism.
    torch.set_num_threads(5)

    # The two modes share every constructor argument except the two mode
    # flags, so derive the flags from the method and build the model once.
    is_cbow = method == 'cbow'
    word2vec = Word2Vec(
        input_file_name=args.input_file_name,
        output_file_name=args.output_file_name,
        emb_dimension=args.emb_dimension,
        batch_size=args.batch_size,
        # window_size is used by the Skip-Gram model
        window_size=args.window_size,
        iteration=args.iteration,
        initial_lr=args.initial_lr,
        min_count=args.min_count,
        using_hs=args.using_hs,
        using_neg=args.using_neg,
        # context_size is used by the CBOW model
        context_size=args.context_size,
        hidden_size=args.hidden_size,
        cbow=is_cbow,
        skip_gram=not is_cbow,
    )

    if is_cbow:
        word2vec.cbow_train()
    else:
        word2vec.skip_gram_train()
评论列表
文章目录