def forward(self, inputs, batch_size, hidden_cell=None):
    if hidden_cell is None:
        # No state was passed in, so initialize the hidden and cell states with zeros.
        # The leading dimension is 2 because the encoder LSTM is bidirectional.
        if use_cuda:
            hidden = Variable(torch.zeros(2, batch_size, hp.enc_hidden_size).cuda())
            cell = Variable(torch.zeros(2, batch_size, hp.enc_hidden_size).cuda())
        else:
            hidden = Variable(torch.zeros(2, batch_size, hp.enc_hidden_size))
            cell = Variable(torch.zeros(2, batch_size, hp.enc_hidden_size))
        hidden_cell = (hidden, cell)
    _, (hidden, cell) = self.lstm(inputs.float(), hidden_cell)
    # hidden is (2, batch_size, hidden_size); concatenate the forward and
    # backward directions to get (batch_size, 2*hidden_size):
    hidden_forward, hidden_backward = torch.split(hidden, 1, 0)
    hidden_cat = torch.cat([hidden_forward.squeeze(0), hidden_backward.squeeze(0)], 1)
    # Project the concatenated state to the latent parameters mu and sigma_hat:
    mu = self.fc_mu(hidden_cat)
    sigma_hat = self.fc_sigma(hidden_cat)
    # sigma_hat is the log-variance; exponentiating half of it gives the standard deviation.
    sigma = torch.exp(sigma_hat / 2.)
    # Reparameterization trick: sample N ~ N(0, 1), then shift and scale it,
    # so gradients can flow back through mu and sigma:
    z_size = mu.size()
    if use_cuda:
        N = Variable(torch.normal(torch.zeros(z_size), torch.ones(z_size)).cuda())
    else:
        N = Variable(torch.normal(torch.zeros(z_size), torch.ones(z_size)))
    z = mu + sigma * N
    # mu and sigma_hat are returned as well, since the KL loss (L_KL) needs them.
    return z, mu, sigma_hat
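Downstream, the returned mu and sigma_hat feed the KL term of the training objective. Below is a minimal sketch of that term following the L_KL formula from the Sketch-RNN paper; the helper name kl_loss is an assumption for illustration, not part of the code above:

import torch

def kl_loss(mu, sigma_hat):
    # KL divergence between the encoder's N(mu, sigma) and the prior N(0, 1),
    # normalized by latent size and batch size as in Sketch-RNN:
    # L_KL = -1/(2 * Nz * batch) * sum(1 + sigma_hat - mu^2 - exp(sigma_hat))
    batch, Nz = mu.size(0), mu.size(1)
    return -0.5 * torch.sum(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat)) / float(Nz * batch)

In training, this term is weighted (the paper anneals a weight w_KL) and added to the reconstruction loss.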