recurrent.py source code

python
Project: R-net  Author: matthew-z
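import torch


def forward(self, inputs, hidden):
    """Run one step through the stacked layers; return (output, next_hidden)."""

    def select_layer(h_state, i):  # To work on both LSTM / GRU, RNN
        # A tuple state (e.g. LSTM's (h, c)) is indexed element by element.
        if isinstance(h_state, tuple):
            return tuple(select_layer(s, i) for s in h_state)
        return h_state[i]

    next_hidden = []
    for i, layer in enumerate(self.layers):
        hidden_i = select_layer(hidden, i)
        next_hidden_i = layer(inputs, hidden_i)
        # When the layer returns a tuple, its first element is the output.
        output = (next_hidden_i[0] if isinstance(next_hidden_i, tuple)
                  else next_hidden_i)
        # Dropout is applied between layers, not after the last one.
        if i + 1 != self.num_layers:
            output = self.dropout(output)
        # Residual connection: add the layer's input, skipping the first layer.
        if i > 0 and self.residual:
            inputs = output + inputs
        else:
            inputs = output
        next_hidden.append(next_hidden_i)

    # Stack the per-layer states back into the layout of the incoming `hidden`.
    if isinstance(hidden, tuple):
        next_hidden = tuple(torch.stack(h) for h in zip(*next_hidden))
    else:
        next_hidden = torch.stack(next_hidden)
    return inputs, next_hidden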
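
The state handling in forward can be tried in isolation. The following is a minimal sketch, not part of the R-net source: plain Python lists stand in for tensors so it runs without PyTorch, list(...) stands in for torch.stack, and the helper name restack is hypothetical. It shows how a tuple state such as (h, c) is split per layer and then regrouped.

```python
def select_layer(h_state, i):
    """Pick layer i from a stacked state; recurse into tuples such as (h, c)."""
    if isinstance(h_state, tuple):
        return tuple(select_layer(s, i) for s in h_state)
    return h_state[i]


def restack(per_layer_states, tuple_state):
    """Regroup per-layer states into the stacked layout (hypothetical helper).

    list(...) is used where the original code calls torch.stack.
    """
    if tuple_state:
        # [(h0, c0), (h1, c1)] -> ([h0, h1], [c0, c1])
        return tuple(list(h) for h in zip(*per_layer_states))
    return list(per_layer_states)


# LSTM-like state: a tuple (h, c), each holding one entry per layer.
lstm_hidden = (["h0", "h1"], ["c0", "c1"])
layer_1 = select_layer(lstm_hidden, 1)

# GRU-like state: a single container with one entry per layer.
gru_hidden = ["g0", "g1"]
gru_layer_0 = select_layer(gru_hidden, 0)

# Splitting per layer and regrouping gives back the original layout.
per_layer = [select_layer(lstm_hidden, i) for i in range(2)]
lstm_restacked = restack(per_layer, tuple_state=True)
```

Under these assumptions, splitting a state layer by layer and regrouping it returns the original layout, which is why forward can hand next_hidden straight back as the hidden argument of the next step.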