metalstm.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:FewShotLearning 作者: gitabcworld 项目源码 文件源码
def _forward_rnn(cell, input_, grads_, length, hx):
        max_time = input_.size(0)
        output = []
        for time in range(max_time):
            hx = cell(input_=input_[time],grads_=grads_[time], hx=hx)
            #mask = (time < length).float().unsqueeze(1).expand_as(h_next[0])
            #fS_next = h_next[0] * mask + hx[0] * (1 - mask)
            #iS_next = h_next[1] * mask + hx[1] * (1 - mask)
            #cS_next = h_next[2] * mask + hx[2] * (1 - mask)
            #deltaS_next = h_next[3] * mask + hx[3] * (1 - mask)
            #hx_next = (fS_next, iS_next, cS_next, deltaS_next)
            #output.append(h_next)
            #hx = hx_next
        #output = torch.stack(output, 0)
        #return output,hx
        #return hx[2],hx
        return hx
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号