model.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:loop 作者: facebookresearch 项目源码 文件源码
def forward(self, x, ident, context, start=True):
        """Run the decoder over ``x`` one time step at a time.

        Args:
            x: input sequence; it is split into length-1 slices along dim 0,
                one slice per step. ``x[0]`` seeds the fed-back prediction
                when not training.
            ident: forwarded to ``init_buffer``, ``update_buffer`` and
                ``F_o`` (presumably an identity/speaker embedding -- confirm).
            context: attended over at every step after swapping dims 0 and 1
                (assumes a sequence-first layout -- TODO confirm).
            start: forwarded unchanged to ``init_buffer``.

        Returns:
            A tuple ``(outputs, attentions)``. ``outputs`` stacks each step's
            prediction along a new leading dim. ``attentions`` stacks each
            step's ``alpha_t.squeeze()``.

        Side effects:
            Calls ``self.init_buffer`` first, then overwrites ``self.S_t``
            and ``self.mu_t`` on every step. Their final values remain on
            ``self`` after the call.
        """
        prev_pred = x[0]
        self.init_buffer(ident, start)
        # transpose(0, 1) only builds a view of an input that never changes
        # inside the loop, so one view can be shared by all steps.
        ctx = context.transpose(0, 1)

        predictions = []
        alignments = []
        for frame in torch.split(x, 1):
            # Training: feed the ground-truth slice. Otherwise: feed back the
            # previous step's prediction (x[0] on the very first step).
            feed = frame if self.training else prev_pred.unsqueeze(0)

            # Attend over the context using the current buffer state.
            c_t, new_mu, alpha_t = self.attn(self.S_t, ctx, self.mu_t)

            # Write the buffer before replacing mu_t: update_buffer is a
            # method and could read self.mu_t, so this order is preserved.
            self.S_t = self.update_buffer(self.S_t, c_t, feed, ident)
            self.mu_t = new_mu

            # Prediction = output(N_o(buffer flattened per row) + F_o(ident)).
            buf = self.S_t
            flat = buf.view(buf.size(0), -1)
            # NOTE(review): F_o(ident) does not depend on the step, but it is
            # kept in the loop in case F_o is stochastic (e.g. dropout) --
            # confirm before hoisting it out.
            prev_pred = self.output(self.N_o(flat) + self.F_o(ident))

            predictions.append(prev_pred)
            # NOTE(review): squeeze() removes *every* size-1 dim, including
            # the batch dim when batch size is 1 -- confirm callers expect it.
            alignments.append(alpha_t.squeeze())

        return torch.stack(predictions), torch.stack(alignments)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号