Split.py source code


Project: DCN  Author: alexnowakvila
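```python
import torch
from torch.autograd import Variable  # legacy wrapper; returns a plain Tensor on PyTorch >= 0.4

# Assumption: the full Split.py defines `dtype` at module level; a common choice is
# the CUDA float tensor type when a GPU is available and the CPU one otherwise.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor


def forward(self, e, input, mask, scale=0):
    # Random initial hidden state: one hidden_size vector per input point.
    hidden = Variable(torch.randn(self.batch_size, self.n,
                                  self.hidden_size)).type(dtype)
    if scale == 0:
        # At scale 0, start from an all-zero assignment vector e.
        e = Variable(torch.zeros(self.batch_size, self.n)).type(dtype)
    Phi = self.build_Phi(e, mask)
    # Row sums of Phi, shape (batch_size, n). Current PyTorch drops the summed
    # dimension, so no squeeze() is needed (squeeze() would also drop the
    # batch dimension when batch_size == 1).
    N = torch.sum(Phi, 2)
    N += (N == 0).type(dtype)  # avoid division by zero
    Nh = N.unsqueeze(2).expand(self.batch_size, self.n,
                               self.hidden_size + self.input_size)
    # Normalize the inputs (an important step), then zero out points where mask is 0.
    mask_inp = mask.unsqueeze(2).expand_as(input)
    input_n = self.Normalize_inputs(Phi, input) * mask_inp
    # Unnormalized alternative: input_n = input * mask_inp
    for layer in self.layers:
        hidden = layer(input_n, hidden, Phi, Nh)
    hidden_p = hidden.view(self.batch_size * self.n, self.hidden_size)
    scores = self.linear_b(hidden_p)
    # probs has shape (batch_size, n); points where mask is 0 get probability 0.
    probs = torch.sigmoid(scores).view(self.batch_size, self.n) * mask
    return scores, probs, input_n, Phi
```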
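The count-and-guard step in `forward` and the final masking can be checked in isolation. The sketch below assumes, as the `torch.sum(Phi, 2)` and `expand` calls suggest, that `Phi` is a (batch, n, n) tensor of 0/1 entries; the helper name `partition_counts`, the toy values, the sizes, and the stand-in `scores` are hypothetical. It shows that an empty row gets a count of 1 instead of 0, that the batch dimension survives the sum, and that multiplying by `mask` zeroes the probabilities of masked-out points.

```python
import torch


def partition_counts(Phi: torch.Tensor) -> torch.Tensor:
    """Row sums of a (batch, n, n) 0/1 tensor, with zero counts replaced by 1.

    Mirrors the `N` computation in `forward`; the function name is hypothetical.
    """
    N = Phi.sum(dim=2)              # shape (batch, n): the summed dim is dropped
    N = N + (N == 0).to(Phi.dtype)  # empty rows count as 1, so dividing by N is safe
    return N


# Hypothetical toy input: batch of 1, n = 3 points; row 2 of Phi is empty.
Phi = torch.tensor([[[1., 1., 0.],
                     [1., 1., 0.],
                     [0., 0., 0.]]])
N = partition_counts(Phi)
print(N)        # tensor([[2., 2., 1.]])

# With batch size 1, an extra squeeze() would drop the batch dim too: shape (3,).
print(Phi.sum(2).squeeze().shape)   # torch.Size([3])

# The batch dim survived, so the Nh-style expand works.
hidden_size, input_size = 4, 2      # hypothetical sizes
Nh = N.unsqueeze(2).expand(1, 3, hidden_size + input_size)
print(Nh.shape)  # torch.Size([1, 3, 6])

# Masked probabilities as at the end of `forward`: points with mask 0 get 0.
mask = torch.tensor([[1., 1., 0.]])  # hypothetical mask
scores = torch.zeros(3, 1)           # stand-in for linear_b output, shape (batch * n, 1)
probs = torch.sigmoid(scores).view(1, 3) * mask
print(probs)     # tensor([[0.5000, 0.5000, 0.0000]])
```

Keeping the batch dimension explicit (summing without squeezing) is what lets the later `unsqueeze(2).expand(...)` and `view(batch_size, n)` calls work for any batch size, including 1.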