Merge.py source code

python

Project: DCN    Author: alexnowakvila
def Decoder(self, input, hidden_encoder, phis,
                input_target=None, target=None):
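        # Attention-based decoding loop: at each step n the recurrent cell
        # updates `hidden`, attention scores over the encoder states are
        # normalized with a softmax masked by the n-th row of `phis` (the
        # scope mask), and the next input is either gathered from
        # `target`/`input_target` (teacher forcing) or blended from `input`
        # with the attention weights. Returns a (batch_size, n, n) matrix of
        # attention distributions, one row per decoding step.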
        feed_target = False
        if target is not None:
            feed_target = True
        # N_n is the number of elements of the scope of the n-th element
        N = phis.sum(2).squeeze().unsqueeze(2).expand(self.batch_size, self.n,
                                                      self.hidden_size)
        output = (Variable(torch.ones(self.batch_size, self.n, self.n))
                  .type(dtype))
        index = ((N[:, 0] - 1) % (self.n)).type(dtype_l).unsqueeze(1)
        hidden = (torch.gather(hidden_encoder, 1, index)).squeeze()
        # W1xe size: (batch_size, n + 1, hidden_size)
        W1xe = (torch.bmm(hidden_encoder, self.W1.unsqueeze(0).expand(
                self.batch_size, self.hidden_size, self.hidden_size)))
        # init token
        start = (self.init_token.unsqueeze(0).expand(self.batch_size,
                 self.input_size))
        input_step = start
        for n in xrange(self.n):
            # decouple interaction between different scopes by looking at
            # subdiagonal elements of Phi
            if n > 0:
                t = (phis[:, n, n - 1].squeeze().unsqueeze(1).expand(
                     self.batch_size, self.hidden_size))
                index = (((N[:, n] + n - 1) % (self.n)).type(dtype_l)
                         .unsqueeze(1))
                init_hidden = (torch.gather(hidden_encoder, 1, index)
                               .squeeze())
                hidden = t * hidden + (1 - t) * init_hidden
                t = (phis[:, n, n - 1].squeeze().unsqueeze(1).expand(
                     self.batch_size, self.input_size))
                input_step = t * input_step + (1 - t) * start
            # Compute next state
            hidden = self.decoder_cell(input_step, hidden)
            # Compute pairwise interactions
            u = self.attention(hidden, W1xe, hidden_encoder, tanh=True)
            # Normalize interactions by taking the masked softmax by phi
            attn = self.softmax_m(phis[:, n].squeeze(), u)
            if feed_target:
                # feed next step with target
                next = (target[:, n].unsqueeze(1).unsqueeze(2)
                        .expand(self.batch_size, 1, self.input_size)
                        .type(dtype_l))
                input_step = torch.gather(input_target, 1, next).squeeze()
            else:
                # blend inputs
                input_step = (torch.sum(attn.unsqueeze(2).expand(
                              self.batch_size, self.n,
                              self.input_size) * input, 1)).squeeze()
            # Update output
            output[:, n] = attn
        return output
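
The decoder above relies on a masked softmax, self.softmax_m(phis[:, n].squeeze(), u), to normalize the attention scores u only over the positions allowed by the n-th row of phis. Below is a minimal sketch of such a masked softmax, assuming the first argument is the 0/1 mask row and the second the raw scores, and that masked-out positions should receive exactly zero weight; masked_softmax and eps are illustrative names, not the project's actual implementation.

import torch

def masked_softmax(mask, scores, eps=1e-6):
    # mask, scores: (batch_size, n); mask entries are 0/1 scope indicators
    scores = scores - scores.max(dim=1, keepdim=True)[0]  # subtract row max for numerical stability
    exp = torch.exp(scores) * mask                         # zero out positions outside the scope
    return exp / (exp.sum(dim=1, keepdim=True) + eps)      # renormalize over the allowed positions

With this behaviour, positions outside the current scope get zero attention weight, so they contribute nothing when the inputs are blended in the else branch (torch.sum(attn ... * input, 1)).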