seq2vec.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:vqa.pytorch 作者: Cadene 项目源码 文件源码
def forward(self, input):
        lengths = process_lengths(input)
        x = self.embedding(input) # seq2seq
        x = getattr(F, 'tanh')(x)
        x_0, hn = self.rnn_0(x)
        vec_0 = select_last(x_0, lengths)

        # x_1 = F.dropout(x_0, p=0.3, training=self.training)
        # print(x_1.size())
        x_1, hn = self.rnn_1(x_0)
        vec_1 = select_last(x_1, lengths)

        vec_0 = F.dropout(vec_0, p=0.3, training=self.training)
        vec_1 = F.dropout(vec_1, p=0.3, training=self.training)
        output = torch.cat((vec_0, vec_1), 1)
        return output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号