bnlstm.py source code

python

Project: FewShotLearning  Author: gitabcworld
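# Module-level imports this method needs:
import torch
from torch.autograd import Variable  # legacy wrapper; since PyTorch 0.4 it simply returns a Tensor

# Method of the LSTM class in bnlstm.py (the rest of the class is not shown).
def forward(self, input_, length=None, hx=None):
    # Convert (batch, time, feature) input to time-major (time, batch, feature).
    if self.batch_first:
        input_ = input_.transpose(0, 1)
    max_time, batch_size, _ = input_.size()
    # By default every sequence in the batch is treated as full length.
    if length is None:
        length = Variable(torch.LongTensor([max_time] * batch_size))
        if input_.is_cuda:
            length = length.cuda()
    # Zero initial state; the same (h_0, c_0) pair is fed to every layer.
    if hx is None:
        hx = Variable(input_.data.new(batch_size, self.hidden_size).zero_())
        hx = (hx, hx)
    h_n = []
    c_n = []
    layer_output = None
    for layer in range(self.num_layers):
        # Run this layer's cell over the whole sequence.
        layer_output, (layer_h_n, layer_c_n) = LSTM._forward_rnn(
            cell=self.cells[layer], input_=input_, length=length, hx=hx)
        # Dropout only feeds the next layer; the last layer's output is returned without it.
        input_ = self.dropout_layer(layer_output)
        h_n.append(layer_h_n)
        c_n.append(layer_c_n)
    # Note: output stays time-major even when batch_first is True.
    output = layer_output
    # Stack the per-layer final states along a new leading layer dimension.
    h_n = torch.stack(h_n, 0)
    c_n = torch.stack(c_n, 0)
    return output, (h_n, c_n)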
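This page does not show the body of `LSTM._forward_rnn` or the cell classes. The following is a minimal, dependency-free sketch of the control flow in `forward`. It assumes that `_forward_rnn` steps through time and freezes each sequence's state once `t >= length`, which is a common way to use a per-sequence `length` argument. All names here are hypothetical: `to_time_major`, `forward_rnn`, `lstm_forward`, and `toy_cell`. Scalar states stand in for the `(batch, hidden_size)` tensors, and `toy_cell` is a stand-in that does not implement batch-normalized LSTM math.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical stand-ins: a scalar (h, c) per sequence keeps the sketch
# dependency-free; the real code uses PyTorch tensors of shape (batch, hidden_size).
State = Tuple[float, float]                 # (h, c) for one sequence
Cell = Callable[[float, State], State]      # cell(x_t, (h, c)) -> (h', c')


def to_time_major(batch: Sequence[Sequence[float]]) -> List[List[float]]:
    """(batch, time) -> (time, batch), mirroring input_.transpose(0, 1)."""
    return [list(step) for step in zip(*batch)]


def forward_rnn(cell: Cell, inputs: Sequence[Sequence[float]],
                lengths: Sequence[int], hx: Sequence[State]):
    """Run one layer over time; padded steps (t >= length) keep the old state."""
    state = list(hx)
    outputs: List[List[float]] = []
    for t, step in enumerate(inputs):
        for b, x in enumerate(step):
            if t < lengths[b]:              # mask: only valid steps update the state
                state[b] = cell(x, state[b])
        outputs.append([h for h, _ in state])
    return outputs, state


def lstm_forward(cells: Sequence[Cell], inputs: Sequence[Sequence[float]],
                 lengths: Sequence[int], hx: Sequence[State],
                 dropout: Callable[[float], float] = lambda v: v):
    """Stack layers the way forward() does: every layer starts from the same hx."""
    h_n: List[List[float]] = []
    c_n: List[List[float]] = []
    layer_input, layer_output = inputs, None
    for cell in cells:
        layer_output, final = forward_rnn(cell, layer_input, lengths, hx)
        # Dropout only feeds the next layer; the last layer's output is returned as is.
        layer_input = [[dropout(v) for v in step] for step in layer_output]
        h_n.append([h for h, _ in final])
        c_n.append([c for _, c in final])
    return layer_output, (h_n, c_n)


def toy_cell(x: float, state: State) -> State:
    """Toy cell (NOT LSTM math): h accumulates inputs, c counts steps taken."""
    h, c = state
    return h + x, c + 1.0


# Usage: batch of 2, max_time 3; the second sequence has only 1 valid step.
inputs = to_time_major([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]])
out, (h_n, c_n) = lstm_forward([toy_cell, toy_cell], inputs, [3, 1],
                               [(0.0, 0.0), (0.0, 0.0)])
print(h_n)  # [[6.0, 10.0], [10.0, 10.0]]
print(c_n)  # [[3.0, 1.0], [3.0, 1.0]]
```

Freezing the state rather than zeroing it means `h_n` holds each sequence's state at its last valid step. As a result, shorter sequences in a padded batch still return a meaningful final state, and the step counter in `c_n` equals each sequence's length.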