auto_reg_nn.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目: pyro 作者: uber 项目源码 文件源码
def __init__(self, input_dim, hidden_dim, output_dim_multiplier=1,
                 mask_encoding=None, permutation=None):
        """Construct a two-layer network whose connectivity is masked to be autoregressive.

        Each hidden unit ``k`` is assigned an integer ``m_k = mask_encoding[k]``.
        Under the input ordering given by ``permutation`` (ordered position ``j``
        maps to actual index ``permutation[j]``):

        * ``mask1[k, permutation[j]]`` is 1 iff ``j < m_k``, so hidden unit ``k``
          is connected only to the inputs at ordered positions ``0 .. m_k - 1``;
        * ``mask2[r * input_dim + permutation[j], k]`` is 1 iff ``j >= m_k``, so the
          output at ordered position ``j`` is connected to hidden unit ``k`` only
          when ``m_k <= j``.

        Together, the output at ordered position ``j`` can only be reached from
        inputs at ordered positions strictly less than ``j``. This assumes
        ``MaskedLinear`` applies its mask to the weight matrix (defined elsewhere;
        confirm).

        Args:
            input_dim: number of inputs; also the size of each output block.
            hidden_dim: number of hidden units.
            output_dim_multiplier: number of output blocks. The output layer has
                ``input_dim * output_dim_multiplier`` units and every block ``r``
                uses the same mask pattern.
            mask_encoding: optional indexable of at least ``hidden_dim`` entries
                giving ``m_k`` for each hidden unit. When ``None``, each ``m_k`` is
                drawn with replacement from ``1 .. input_dim - 1`` using uniform
                weights via ``torch_multinomial`` (a helper defined elsewhere).
            permutation: optional indexable of at least ``input_dim`` entries
                mapping ordered position to actual index. When ``None``,
                ``torch.randperm(input_dim)`` is used.

        Raises:
            ValueError: if ``mask_encoding`` is ``None`` and ``input_dim < 2``.
                There is then no value in ``1 .. input_dim - 1`` to sample, and the
                original code failed obscurely inside the sampler.
        """
        super(AutoRegressiveNN, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim_multiplier = output_dim_multiplier

        if mask_encoding is None:
            if input_dim < 2:
                # torch.ones(input_dim - 1) would be empty (or a negative size), so
                # there is nothing to sample from; report the real cause up front.
                raise ValueError("input_dim must be at least 2 when mask_encoding is not "
                                 "given, got input_dim={}".format(input_dim))
            # the dependency structure is chosen at random: sampled category indices
            # 0 .. input_dim - 2 are shifted by 1 to give m_k in 1 .. input_dim - 1
            self.mask_encoding = 1 + torch_multinomial(torch.ones(input_dim - 1) / (input_dim - 1),
                                                       num_samples=hidden_dim, replacement=True)
        else:
            # the dependency structure is given by the user
            self.mask_encoding = mask_encoding

        if permutation is None:
            # a permutation is chosen at random
            self.permutation = torch.randperm(input_dim)
        else:
            # the permutation is chosen by the user
            self.permutation = permutation

        # these masks control the autoregressive structure:
        # mask1 is (hidden_dim, input_dim) and mask2 is
        # (input_dim * output_dim_multiplier, hidden_dim), both starting at zero
        self.mask1 = Variable(torch.zeros(hidden_dim, input_dim))
        self.mask2 = Variable(torch.zeros(input_dim * self.output_dim_multiplier, hidden_dim))

        for k in range(hidden_dim):
            # fill in mask1: hidden unit k sees ordered positions j < m_k
            m_k = self.mask_encoding[k]
            slice_k = torch.cat([torch.ones(m_k), torch.zeros(input_dim - m_k)])
            for j in range(input_dim):
                self.mask1[k, self.permutation[j]] = slice_k[j]
            # fill in mask2: the output at ordered position j sees hidden unit k iff
            # j >= m_k (the complement of mask1's pattern), which keeps every output
            # strictly after the inputs feeding it
            slice_k = torch.cat([torch.zeros(m_k), torch.ones(input_dim - m_k)])
            for r in range(self.output_dim_multiplier):
                # each output block r (rows r*input_dim .. (r+1)*input_dim - 1)
                # repeats the same pattern
                for j in range(input_dim):
                    self.mask2[r * input_dim + self.permutation[j], k] = slice_k[j]

        self.lin1 = MaskedLinear(input_dim, hidden_dim, self.mask1)
        self.lin2 = MaskedLinear(hidden_dim, input_dim * output_dim_multiplier, self.mask2)
        # activation module; how it is applied between lin1 and lin2 is in forward()
        # (not visible here)
        self.relu = nn.ReLU()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号