def forward(self, h, Q, u):
    """Apply the per-sample affine transition to ``h`` and to ``Q``'s mean.

    For each batch element this builds the map ``x -> A x + B u + o`` where
    ``A = I + v r^T`` (``v`` and ``r`` are the two halves of ``self.trans(h)``
    along dim 1), ``B`` is ``self.fc_B(h)`` reshaped to ``(dim_z, dim_u)`` and
    ``o`` is ``self.fc_o(h)``. The map is applied both to the sample ``h`` and
    to the mean ``Q.mu``.

    Args:
        h: batch of latent vectors, shape ``(batch, dim_z)`` (required by
            ``A.bmm(h.unsqueeze(2))``).
        Q: distribution-like object exposing ``mu``, ``sigma`` and
            ``logsigma``; ``Q.mu`` must be ``(batch, dim_z)``.
        u: batch of control inputs, shape ``(batch, dim_u)`` (required by
            ``B.bmm(u.unsqueeze(2))``).

    Returns:
        ``(sample, dist)``: ``sample = A h + B u + o`` with shape
        ``(batch, dim_z)``, and ``dist`` is a ``NormalDistribution`` whose
        first argument is ``A Q.mu + B u + o``. ``Q.sigma`` and ``Q.logsigma``
        are passed through unchanged, together with ``v`` and ``r``.
    """
    batch_size = h.size()[0]
    v, r = self.trans(h).chunk(2, dim=1)
    v1 = v.unsqueeze(2)  # (batch, dim_z, 1)
    rT = r.unsqueeze(1)  # (batch, 1, dim_z)

    I = Variable(torch.eye(self.dim_z).repeat(batch_size, 1, 1))
    if rT.data.is_cuda:
        # .cuda() is not in-place: it returns a copy, so rebind I.
        I = I.cuda()
    A = I.add(v1.bmm(rT))  # I + v r^T, (batch, dim_z, dim_z)

    B = self.fc_B(h).view(-1, self.dim_z, self.dim_u)
    # Make o a column vector so it adds element-wise to the (batch, dim_z, 1)
    # bmm results below; a bare 2-D (batch, dim_z) tensor would broadcast to
    # (batch, dim_z, dim_z) or fail. Assumes fc_o outputs dim_z features.
    o = self.fc_o(h).unsqueeze(2)

    # need to compute the parameters for distributions
    # as well as for the samples; the B u + o term is shared by both.
    u = u.unsqueeze(2)
    offset = B.bmm(u).add(o)

    d = A.bmm(Q.mu.unsqueeze(2)).add(offset).squeeze(2)
    sample = A.bmm(h.unsqueeze(2)).add(offset).squeeze(2)
    return sample, NormalDistribution(d, Q.sigma, Q.logsigma, v=v, r=r)
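# 评论列表 ("comment list") - stray web-page text captured with this snippet, not code
# 文章目录 ("table of contents") - stray web-page text captured with this snippet, not code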