urnn.py 文件源码

python
阅读 76 收藏 0 点赞 0 评论 0

项目:URNN-PyTorch 作者: jingli9111 项目源码 文件源码
def _forward_rnn(cell, input_, length, hx):
        max_time = input_.size(0)
        output = []
        for time in range(max_time):
            h_next = cell(input_=input_[time], hx=hx)
            # mask = (time < length).float().unsqueeze(1).expand_as(h_next)
            # h_next = h_next*mask + hx*(1 - mask)
            output.append(h_next)
        output = torch.stack(output, 0)
        return output, h_next
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号