def forward(self, prev_samples, upper_tier_conditioning):
    """Predict per-step log-probabilities over ``self.q_levels`` classes.

    Args:
        prev_samples: Tensor of indices looked up in ``self.embedding``;
            it is flattened before the lookup. Presumably quantized sample
            values in ``[0, q_levels)`` -- TODO confirm against the caller.
        upper_tier_conditioning: 3-D tensor whose first axis is the batch.
            After swapping its last two axes it is added to the output of
            ``self.input``, so the shapes must agree (assumed to be
            ``(batch, steps, hidden)`` -- TODO confirm).

    Returns:
        Tensor of shape ``(batch_size, -1, self.q_levels)`` holding
        log-softmax scores normalized over the last axis.
    """
    # Unpacking (rather than .size(0)) also rejects non-3-D conditioning.
    (batch_size, _, _) = upper_tier_conditioning.size()

    # Embed every index, then restore a (batch, steps, q_levels) layout.
    # NOTE(review): this view assumes the embedding width equals q_levels;
    # otherwise it silently re-chunks the embedded values.
    prev_samples = self.embedding(
        prev_samples.contiguous().view(-1)
    ).view(
        batch_size, -1, self.q_levels
    )

    # Swap to channels-first layout for the layers below
    # (presumably Conv1d modules -- TODO confirm).
    prev_samples = prev_samples.permute(0, 2, 1)
    upper_tier_conditioning = upper_tier_conditioning.permute(0, 2, 1)

    x = F.relu(self.input(prev_samples) + upper_tier_conditioning)
    x = F.relu(self.hidden(x))
    # Back to (batch, steps, channels); contiguous() makes the view() legal.
    x = self.output(x).permute(0, 2, 1).contiguous()

    # Pass dim explicitly: the implicit-dim form is deprecated and emits a
    # UserWarning on every call. For this 2-D view the implicit choice was
    # dim=1 (the last axis), so the returned values are unchanged.
    return F.log_softmax(x.view(-1, self.q_levels), dim=-1).view(
        batch_size, -1, self.q_levels
    )
评论列表
文章目录