Split.py 文件源码

python
阅读 39 收藏 0 点赞 0 评论 0

项目:DCN 作者: alexnowakvila 项目源码 文件源码
def forward(self, input_n, hidden, phi, nh):
        hidden = torch.cat((hidden, input_n), 2)
        # Aggregate reresentations
        h_conv = torch.div(torch.bmm(phi, hidden), nh)
        hidden = hidden.view(-1, self.hidden_size + self.input_size)
        h_conv = h_conv.view(-1, self.hidden_size + self.input_size)
        # h_conv has shape (batch_size, n, hidden_size + input_size)
        m1 = (torch.mm(hidden, self.W1)
              .view(self.batch_size, -1, self.hidden_size))
        m2 = (torch.mm(h_conv, self.W2)
              .view(self.batch_size, -1, self.hidden_size))
        m3 = self.b.unsqueeze(0).unsqueeze(1).expand_as(m2)
        hidden = torch.sigmoid(m1 + m2 + m3)
        return hidden
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号