def __init__(self, input_dim, hidden_dim, output_dim_multiplier=1,
             mask_encoding=None, permutation=None):
    """Build the masked two-layer autoregressive network.

    :param input_dim: dimensionality of the input variable
    :param hidden_dim: number of hidden units
    :param output_dim_multiplier: how many output slots per input dimension
        (output layer has ``input_dim * output_dim_multiplier`` units)
    :param mask_encoding: optional 1-D tensor of per-hidden-unit degrees in
        ``{1, ..., input_dim - 1}``; sampled uniformly at random when ``None``
    :param permutation: optional ordering of the input dimensions; a random
        permutation is drawn when ``None``
    """
    super(AutoRegressiveNN, self).__init__()
    self.input_dim = input_dim
    self.hidden_dim = hidden_dim
    self.output_dim_multiplier = output_dim_multiplier

    if mask_encoding is not None:
        # dependency structure supplied by the caller
        self.mask_encoding = mask_encoding
    else:
        # sample each hidden unit's degree uniformly from {1, ..., input_dim - 1}
        # NOTE: this multinomial draw happens BEFORE the randperm below, so the
        # RNG consumption order matches the original implementation
        degree_probs = torch.ones(input_dim - 1) / (input_dim - 1)
        self.mask_encoding = 1 + torch_multinomial(degree_probs,
                                                   num_samples=hidden_dim,
                                                   replacement=True)

    if permutation is not None:
        # input ordering supplied by the caller
        self.permutation = permutation
    else:
        # otherwise order the inputs at random
        self.permutation = torch.randperm(input_dim)

    # masks enforcing the autoregressive property of the two linear layers
    self.mask1 = Variable(torch.zeros(hidden_dim, input_dim))
    self.mask2 = Variable(torch.zeros(input_dim * self.output_dim_multiplier, hidden_dim))
    for h in range(hidden_dim):
        degree = self.mask_encoding[h]
        # hidden unit h sees the first `degree` inputs (in permuted order) ...
        in_pattern = torch.cat([torch.ones(degree), torch.zeros(input_dim - degree)])
        # ... and feeds only the outputs strictly after them
        out_pattern = torch.cat([torch.zeros(degree), torch.ones(input_dim - degree)])
        for i in range(input_dim):
            target = self.permutation[i]
            self.mask1[h, target] = in_pattern[i]
            # replicate the output pattern across every output block
            for r in range(self.output_dim_multiplier):
                self.mask2[r * input_dim + target, h] = out_pattern[i]

    self.lin1 = MaskedLinear(input_dim, hidden_dim, self.mask1)
    self.lin2 = MaskedLinear(hidden_dim, input_dim * output_dim_multiplier, self.mask2)
    self.relu = nn.ReLU()