model.py source file

python
Project: BuboQA  Author: castorini
import torch.nn.functional as F


def forward(self, batch):
    # batch.question has shape (sequence length, batch size)
    inputs = self.embed(batch.question)  # (sequence length, batch size, embedding dimension)
    batch_size = inputs.size(1)
    state_shape = (self.config.n_cells, batch_size, self.config.d_hidden)
    # Zero initial state with the same device and dtype as the inputs
    h0 = inputs.new_zeros(state_shape)
    if self.config.rnn_type.lower() == 'gru':
        outputs, ht = self.rnn(inputs, h0)
    else:
        # An LSTM takes an (h0, c0) tuple; both start at zero
        c0 = inputs.new_zeros(state_shape)
        outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
    # outputs: (sequence length, batch size, hidden size * num directions)
    # Flatten time and batch so every token position is scored independently
    tags = self.hidden2tag(outputs.view(-1, outputs.size(2)))
    # Normalize over the tag dimension; an explicit dim avoids the implicit-dim warning
    scores = F.log_softmax(tags, dim=1)
    return scores
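
The forward method above depends on attributes (`self.embed`, `self.rnn`, `self.hidden2tag`, `self.config`) that are defined elsewhere in the project. To show the shapes end to end, here is a minimal, self-contained sketch. The class name `Tagger`, the config fields `n_vocab`, `d_embed`, `n_layers`, `birnn`, and `n_tags`, and all sizes are assumptions made for illustration, not the project's actual definitions. It also assumes `n_cells` equals the number of layers times the number of directions, which is the shape PyTorch expects for an initial RNN state.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F


class Tagger(nn.Module):
    """Hypothetical wrapper; layer names follow the forward method above."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.n_vocab, config.d_embed)
        rnn_cls = nn.GRU if config.rnn_type.lower() == 'gru' else nn.LSTM
        self.rnn = rnn_cls(config.d_embed, config.d_hidden,
                           num_layers=config.n_layers,
                           bidirectional=config.birnn)
        n_directions = 2 if config.birnn else 1
        # Assumption: n_cells = layers * directions (first dim of h0 / c0)
        config.n_cells = config.n_layers * n_directions
        self.hidden2tag = nn.Linear(config.d_hidden * n_directions,
                                    config.n_tags)

    def forward(self, batch):
        inputs = self.embed(batch.question)
        state_shape = (self.config.n_cells, inputs.size(1),
                       self.config.d_hidden)
        h0 = inputs.new_zeros(state_shape)
        if self.config.rnn_type.lower() == 'gru':
            outputs, _ = self.rnn(inputs, h0)
        else:
            c0 = inputs.new_zeros(state_shape)
            outputs, _ = self.rnn(inputs, (h0, c0))
        tags = self.hidden2tag(outputs.view(-1, outputs.size(2)))
        return F.log_softmax(tags, dim=1)


# Illustrative sizes only
config = SimpleNamespace(n_vocab=50, d_embed=8, d_hidden=16, n_layers=2,
                         birnn=True, rnn_type='lstm', n_tags=3)
model = Tagger(config)
# Token ids with shape (sequence length = 5, batch size = 4)
batch = SimpleNamespace(question=torch.randint(0, 50, (5, 4)))
scores = model(batch)
print(scores.shape)  # one row of tag log-probabilities per token position
```

Because the time and batch dimensions are flattened before the linear layer, the result has one row per token (sequence length times batch size) and one column per tag. Exponentiating a row yields a probability distribution over the tags.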