seq2vec.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:vqa.pytorch 作者: Cadene 项目源码 文件源码
def factory(vocab_words, opt):
    if opt['arch'] == 'skipthoughts':
        st_class = getattr(skipthoughts, opt['type'])
        seq2vec = st_class(opt['dir_st'],
                           vocab_words,
                           dropout=opt['dropout'],
                           fixed_emb=opt['fixed_emb'])
    elif opt['arch'] == '2-lstm':
        seq2vec = TwoLSTM(vocab_words,
                          opt['emb_size'],
                          opt['hidden_size'])
    elif opt['arch'] == 'lstm':
        seq2vec = TwoLSTM(vocab_words,
                          opt['emb_size'],
                          opt['hidden_size'],
                          opt['num_layers'])
    else:
        raise NotImplementedError
    return seq2vec
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号