model.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:loop 作者: facebookresearch 项目源码 文件源码
def update_buffer(self, S_tm1, c_t, o_tm1, ident):
        # concat previous output & context
        idt = torch.tanh(self.F_u(ident))
        o_tm1 = o_tm1.squeeze(0)
        z_t = torch.cat([c_t + idt, o_tm1/30], 1)
        z_t = z_t.unsqueeze(2)
        Sp = torch.cat([z_t, S_tm1[:, :, :-1]], 2)

        # update S
        u = self.N_u(Sp.view(Sp.size(0), -1))
        u[:, :idt.size(1)] = u[:, :idt.size(1)] + idt
        u = u.unsqueeze(2)
        S = torch.cat([u, S_tm1[:, :, :-1]], 2)

        return S
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号