def reinforce_backward(self, reward, output_mask=None):
    """
    Backward pass for REINFORCE on the sampled program tokens: the reward is
    attached to every sampled output via the (pre-0.4) Tensor.reinforce API.
    If output_mask is not None, it should be a FloatTensor of shape (N, T)
    giving a per-timestep multiplier applied to the gradients.
    """
    assert self.multinomial_outputs is not None, 'Must call reinforce_sample first'
    grad_output = []

    def gen_hook(mask):
        # Build a backward hook that scales this timestep's gradient by the mask.
        def hook(grad):
            return grad * mask.contiguous().view(-1, 1).expand_as(grad)
        return hook

    if output_mask is not None:
        # Register one hook per timestep so masked-out positions contribute zero gradient.
        for t, probs in enumerate(self.multinomial_probs):
            mask = Variable(output_mask[:, t])
            probs.register_hook(gen_hook(mask))

    for sampled_output in self.multinomial_outputs:
        # Attach the reward to each stochastic sample (old .reinforce() API);
        # autograd then uses it to form the REINFORCE gradient estimate.
        sampled_output.reinforce(reward)
        grad_output.append(None)
    torch.autograd.backward(self.multinomial_outputs, grad_output, retain_variables=True)
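
Note that Tensor.reinforce(), Variable, and the retain_variables keyword were removed in later PyTorch releases. Below is a minimal sketch, not the original model's code, of how the same update can be expressed in PyTorch >= 0.4 with torch.distributions: the surrogate loss -(reward * log_prob) produces the same policy-gradient direction that .reinforce() used to inject. The function and variable names (sample_program, reinforce_backward_modern, logits_per_step) are hypothetical.

# Hedged sketch: modern equivalent of reinforce_sample + reinforce_backward.
# Assumes the decoder emits a list of per-timestep logits of shape (N, vocab_size).
import torch
from torch.distributions import Categorical

def sample_program(logits_per_step):
    """Sample one token per timestep and keep the log-probs for REINFORCE."""
    samples, log_probs = [], []
    for logits in logits_per_step:           # each: (N, vocab_size)
        dist = Categorical(logits=logits)
        token = dist.sample()                # (N,)
        samples.append(token)
        log_probs.append(dist.log_prob(token))
    return samples, log_probs

def reinforce_backward_modern(log_probs, reward, output_mask=None):
    """Backprop -(reward * log_prob), optionally masking timesteps like output_mask above."""
    loss = 0.0
    for t, log_prob in enumerate(log_probs):  # log_prob: (N,)
        step = reward * log_prob
        if output_mask is not None:
            step = step * output_mask[:, t]
        loss = loss - step.mean()
    loss.backward()

# Usage sketch: two timesteps of fake decoder logits for a batch of 4.
logits_per_step = [torch.randn(4, 10, requires_grad=True) for _ in range(2)]
samples, log_probs = sample_program(logits_per_step)
reinforce_backward_modern(log_probs, reward=torch.ones(4))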