def factory(vocab_words, opt):
    """Build the sentence encoder (seq2vec) selected by ``opt['arch']``.

    Args:
        vocab_words: Vocabulary forwarded to the chosen encoder's constructor.
        opt: Options mapping. ``opt['arch']`` is always read and must be one
            of ``'skipthoughts'``, ``'2-lstm'`` or ``'lstm'``. The other keys
            read depend on the architecture:

            * ``'skipthoughts'``: ``'type'`` (attribute name looked up on the
              ``skipthoughts`` module), ``'dir_st'``, ``'dropout'``,
              ``'fixed_emb'``.
            * ``'2-lstm'``: ``'emb_size'``, ``'hidden_size'``.
            * ``'lstm'``: ``'emb_size'``, ``'hidden_size'``, ``'num_layers'``.

    Returns:
        The constructed encoder instance.

    Raises:
        KeyError: If ``opt`` lacks ``'arch'`` or a key the chosen
            architecture reads.
        NotImplementedError: If ``opt['arch']`` is not a supported value.
    """
    arch = opt['arch']
    if arch == 'skipthoughts':
        # The concrete skip-thoughts variant is selected by name at runtime.
        st_class = getattr(skipthoughts, opt['type'])
        seq2vec = st_class(opt['dir_st'],
                           vocab_words,
                           dropout=opt['dropout'],
                           fixed_emb=opt['fixed_emb'])
    elif arch == '2-lstm':
        seq2vec = TwoLSTM(vocab_words,
                          opt['emb_size'],
                          opt['hidden_size'])
    elif arch == 'lstm':
        # NOTE(review): this branch also instantiates TwoLSTM, but passes an
        # extra ``num_layers`` argument that the '2-lstm' branch does not. If
        # TwoLSTM's constructor only accepts (vocab, emb_size, hidden_size),
        # this raises TypeError; a separate multi-layer LSTM class may have
        # been intended here -- confirm against the class definitions.
        seq2vec = TwoLSTM(vocab_words,
                          opt['emb_size'],
                          opt['hidden_size'],
                          opt['num_layers'])
    else:
        raise NotImplementedError(
            "Unsupported seq2vec arch {!r}; expected one of "
            "'skipthoughts', '2-lstm', 'lstm'".format(arch))
    return seq2vec
评论列表
文章目录