seq2seq.py source code

python

Project: clevr-iep   Author: facebookresearch

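import torch
from torch.autograd import Variable

# Method of the seq2seq model class (the rest of the class is not shown here).
# Written against early PyTorch: the stochastic Variable.reinforce() API has
# since been removed in favor of torch.distributions, and the retain_variables
# argument was renamed retain_graph.
def reinforce_backward(self, reward, output_mask=None):
    """
    Backpropagate a REINFORCE reward through the outputs sampled by
    reinforce_sample.

    If output_mask is not None, it should be a FloatTensor of shape (N, T);
    entry [i, t] multiplies the gradient flowing into the output
    probabilities of sample i at timestep t (0 blocks that step entirely).
    """
    assert self.multinomial_outputs is not None, 'Must call reinforce_sample first'
    grad_output = []

    def gen_hook(mask):
      # Hook that rescales each row of the gradient w.r.t. probs by that
      # sample's mask value: (N,) -> (N, 1) -> broadcast to grad's shape.
      def hook(grad):
        return grad * mask.contiguous().view(-1, 1).expand_as(grad)
      return hook

    if output_mask is not None:
      for t, probs in enumerate(self.multinomial_probs):
        mask = Variable(output_mask[:, t])
        probs.register_hook(gen_hook(mask))

    # Attach the reward to every sampled output; stochastic outputs take no
    # explicit upstream gradient, hence the None entries.
    for sampled_output in self.multinomial_outputs:
      sampled_output.reinforce(reward)
      grad_output.append(None)
    # retain_variables=True keeps the graph's buffers so it can be backpropagated again.
    torch.autograd.backward(self.multinomial_outputs, grad_output, retain_variables=True)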
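The gradient that reinforce_backward accumulates can be spelled out without PyTorch. The sketch below is a plain-Python illustration, not code from clevr-iep: the function name masked_reinforce_grad and the nested-list layout (probs[t][i][v], samples[t][i], reward[i], output_mask[i][t]) are hypothetical. It assumes, as we understand early PyTorch's multinomial .reinforce(), that each sampled output contributes the gradient of the surrogate loss -reward * log p(sampled token), which is -reward / p at the sampled index and 0 elsewhere; the hook then multiplies that row by output_mask[i, t].

```python
from typing import List, Optional

Grid = List[List[float]]


def masked_reinforce_grad(
    probs: List[Grid],                 # hypothetical layout: probs[t][i][v]
    samples: List[List[int]],          # samples[t][i]: sampled token index
    reward: List[float],               # reward[i]: scalar reward for sample i
    output_mask: Optional[Grid] = None,  # output_mask[i][t], shape (N, T)
) -> List[Grid]:
    """Gradient of L = -sum_{t,i} mask[i][t] * reward[i] * log probs[t][i][a]
    with respect to every entry of probs (a = samples[t][i])."""
    grads: List[Grid] = []
    for t, probs_t in enumerate(probs):
        grad_t: Grid = []
        for i, row in enumerate(probs_t):
            g = [0.0] * len(row)
            a = samples[t][i]
            m = 1.0 if output_mask is None else output_mask[i][t]
            # d/dp_a of -(m * r * log p_a) is -m * r / p_a; other entries stay 0.
            g[a] = -m * reward[i] / row[a]
            grad_t.append(g)
        grads.append(grad_t)
    return grads


# Usage: two samples (N = 2), two timesteps (T = 2), vocabulary of size 2.
probs = [[[0.5, 0.5], [0.25, 0.75]],   # t = 0
         [[0.2, 0.8], [0.5, 0.5]]]     # t = 1
samples = [[0, 1], [1, 0]]
reward = [1.0, 2.0]
mask = [[1.0, 1.0], [1.0, 0.0]]        # sample 1's second step is masked out
grads = masked_reinforce_grad(probs, samples, reward, mask)
unmasked = masked_reinforce_grad(probs, samples, reward)
```

Setting a mask entry to 0 zeroes the whole gradient row for that sample and step, which is the effect the registered hook has on the corresponding row of probs.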