def forward(self, prob, target, reward):
    """Return the negated, reward-weighted sum of each row's target entry.

    For every row ``i`` the entry ``prob[i, target[i]]`` is selected, and
    the result is ``-sum_i(prob[i, target[i]] * reward[i])``.

    Args:
        prob: (N, C) torch tensor/Variable of per-class scores.
            NOTE(review): in policy-gradient style losses this is typically
            log-probabilities -- confirm with the caller.
        target: (N,) integer (long) torch tensor/Variable of class indices,
            on the same device as ``prob``.
        reward: (N,) torch tensor/Variable of per-sample weights.

    Returns:
        A scalar tensor holding the negated weighted sum.
    """
    # gather picks prob[i, target[i]] directly. It replaces the former
    # one-hot + masked_select construction, which built the mask on the CPU,
    # copied it to the GPU, round-tripped it back to the CPU through
    # torch.ByteTensor and then to the GPU again, and relied on a deprecated
    # uint8 mask. Entries come out in row order, exactly as masked_select
    # returned them, and gradients flow to the same positions.
    picked = prob.gather(1, target.view(-1, 1)).squeeze(1)
    return -torch.sum(picked * reward)
评论列表
文章目录