DCN.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:DCN 作者: alexnowakvila 项目源码 文件源码
def fwd_split(self, input, batch, depth,
              random_split=False, mode='train', epoch=0):
        """Run ``depth`` rounds of binary splitting and collect the results.

        At every scale the split module is queried with the current
        embedding, a binary decision is sampled for each position, and the
        decision is pushed into the embedding as a new least-significant
        bit (``e = 2 * e + sample``).

        Args:
            input: batch tensor; positions whose first feature
                (``input[:, :, 0]``) is negative are treated as padding.
            batch: not used by this method.
            depth: number of scales (split rounds) to run.
            random_split: when true, decisions are drawn by thresholding
                uniform noise at 0.5 and the split probabilities are
                ignored; otherwise decisions are drawn by comparing the
                split probabilities against uniform noise.
            mode: not used by this method.
            epoch: not used by this method.

        Returns:
            Tuple ``(var, Phis, Bs, Inputs_N, e, Log_Probs)``: the summed
            result of ``self.compute_variance`` over all scales, the
            per-scale lists of ``Phi``, probabilities and ``input_n``
            returned by ``self.split``, the final embedding with padded
            positions set to ``1e6``, and the output of
            ``self.log_probabilities``.
        """
        n_items = self.split.n
        # Embedding codes start at zero for every position.
        e = Variable(torch.zeros(self.batch_size, n_items)).type(dtype)
        # 1 where the first input feature is non-negative, 0 on padding.
        # NOTE(review): squeeze() would also drop a size-1 batch or length
        # dimension -- confirm callers never hit that case.
        mask = (input[:, :, 0] >= 0).type(dtype).squeeze()
        variance_total = 0.0
        Phis, Bs, Inputs_N, Samples = [], [], [], []
        for scale in xrange(depth):
            _logits, probs, input_n, Phi = self.split(
                e, input, mask, scale=scale)
            # Fresh noise for this scale; init.uniform fills the tensor in
            # place (presumably from U(0, 1), the library default).
            noise = Variable(torch.zeros(self.batch_size, n_items))
            noise = noise.type(dtype)
            init.uniform(noise)
            if random_split:
                # Fair coin flip, independent of the predicted probabilities.
                decision = noise > 0.5
            else:
                # Sample the binary decision from the split probabilities.
                decision = probs > noise
            sample = decision.type(dtype)
            # Shift the code left by one bit and append the new decision.
            e = 2 * e + sample
            # Keep the per-scale outputs.
            Samples.append(sample)
            Phis.append(Phi)
            Bs.append(probs)
            Inputs_N.append(input_n)
            # Variance of the Bernoulli probabilities at this scale.
            variance_total += self.compute_variance(probs, mask)
        # Log probabilities of the binary actions, for the policy gradient.
        Log_Probs = self.log_probabilities(Bs, Samples, mask, depth)
        # Give padded positions a huge code so they do not interfere with
        # an argsort over the embeddings.
        pad_value = 1e6
        e = e * mask + (1 - mask) * pad_value
        return variance_total, Phis, Bs, Inputs_N, e, Log_Probs

    ###########################################################################
    #                             Merge Phase                                 #
    ###########################################################################
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号