import torch
import torch.nn.functional as F
from torch.autograd import Variable  # pre-0.4 PyTorch style, matching the original snippet


def forward(self, fc, seq_len):
    '''
    fc: [bsz, max_len, fc_size], has already passed through a sigmoid layer
    seq_len: [bsz], the actual length of each sequence in the batch
    '''
    loss = Variable(torch.zeros(1), requires_grad=True)
    if fc.is_cuda:
        # Lazily move the projection layer and the loss accumulator to the GPU.
        self.age_dis_trans.cuda()
        loss = loss.cuda()
    bsz, max_len = fc.size()[0:2]
    # Merge the batch and time dimensions so the projection layer is
    # applied to every time step at once.
    fc = fc.view(bsz * max_len, -1)
    log_prob = F.log_softmax(self.age_dis_trans(fc), dim=-1).view(bsz, max_len, -1)
    # Detach the target distribution: gradients should only flow through
    # the predicted (log-probability) side of the KL divergence.
    prob = log_prob.detach().exp()
    seq_len = seq_len.long()
    for i in range(bsz):
        l = seq_len.data[i] - 1  # number of adjacent (t, t+1) pairs; assumes seq_len >= 2
        # F.kl_div(input, target) takes log-probabilities as input, so this is
        # KL(distribution at step t+1 || prediction at step t), summed over the
        # valid steps (size_average=False) and averaged over the sequence length.
        loss = loss + F.kl_div(log_prob[i, 0:l], prob[i, 1:(l + 1)], False) / l
    loss = loss / bsz
    return loss
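
For context, here is a minimal sketch of how this forward might be wired up and called. The AgeSmoothnessLoss class name, the nn.Linear projection used for age_dis_trans, and the dimensions (fc_size=128, n_ages=8) are all assumptions for illustration; the original post does not show the enclosing module.

import torch
import torch.nn as nn
from torch.autograd import Variable


class AgeSmoothnessLoss(nn.Module):  # hypothetical wrapper, not from the original post
    def __init__(self, fc_size, n_ages):
        super(AgeSmoothnessLoss, self).__init__()
        # Assumed: age_dis_trans projects per-step features to per-age logits.
        self.age_dis_trans = nn.Linear(fc_size, n_ages)

# Attach the forward method shown above to the hypothetical module.
AgeSmoothnessLoss.forward = forward

# Usage with random data: bsz=4, max_len=10, fc_size=128, 8 age buckets.
criterion = AgeSmoothnessLoss(fc_size=128, n_ages=8)
fc = Variable(torch.sigmoid(torch.randn(4, 10, 128)))
seq_len = Variable(torch.LongTensor([10, 7, 9, 5]))  # every length >= 2
loss = criterion(fc, seq_len)
loss.backward()

Because the target prob is detached, gradients only flow through each step's predicted distribution, so the loss in effect encourages adjacent time steps to predict similar age distributions rather than acting as a supervised objective.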