seq2seq.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:clevr-iep 作者: facebookresearch 项目源码 文件源码
def reinforce_sample(self, x, max_length=30, temperature=1.0, argmax=False):
    """Decode a batch of token sequences step by step, recording each choice.

    At every step the decoder scores are divided by ``temperature`` and
    softmax-normalized over the vocabulary. The next token is then either the
    highest-probability one (``argmax=True``) or drawn with ``multinomial``.
    Each step's chosen tokens and probabilities are stored on ``self``;
    presumably a REINFORCE-style update elsewhere consumes them (verify
    against callers). Decoding stops early once every sequence has emitted
    ``self.END``.

    Args:
      x: batch of encoder inputs. Here only ``x.size(0)`` (batch size N) is
        read, plus ``x.data`` for the type of the START input and the output.
      max_length: maximum number of decoding steps T.
      temperature: divisor applied to the decoder scores before the softmax.
      argmax: if True, take the most likely token instead of sampling.

    Returns:
      An N x T tensor of token ids, converted with ``type_as(x.data)``. The
      END token is kept in the output. Positions after a sequence's END, and
      steps never run because of the early stop, stay ``self.NULL``.

    Side effects:
      Resets ``self.multinomial_outputs`` (N x 1 chosen-token tensors) and
      ``self.multinomial_probs`` (N x V probability tensors). Each list gets
      one entry per decoding step actually run.

    NOTE: uses bool masks, so this needs PyTorch >= 1.2.
    """
    N, T = x.size(0), max_length
    encoded = self.encoder(x)
    # Output buffer lives on the CPU; it is converted to x's type on return.
    y = torch.LongTensor(N, T).fill_(self.NULL)
    # done[i] flips to True once sequence i has produced END. Bool masks are
    # used because indexing with uint8 masks is deprecated in PyTorch.
    done = torch.zeros(N, dtype=torch.bool)
    cur_input = Variable(x.data.new(N, 1).fill_(self.START))
    h, c = None, None
    self.multinomial_outputs = []
    self.multinomial_probs = []
    for t in range(T):
      # logprobs is N x 1 x V
      logprobs, h, c = self.decoder(encoded, cur_input, h0=h, c0=c)
      logprobs = logprobs / temperature
      probs = F.softmax(logprobs.view(N, -1), dim=1)  # Now N x V
      if argmax:
        # keepdim=True preserves the N x 1 shape of the START input fed to
        # the decoder; max() drops the reduced dim by default.
        _, cur_output = probs.max(1, keepdim=True)
      else:
        # num_samples must be given explicitly; the result is N x 1.
        cur_output = probs.multinomial(1)
      self.multinomial_outputs.append(cur_output)
      self.multinomial_probs.append(probs)
      # Flatten to length N so it lines up with the per-sequence masks.
      cur_output_data = cur_output.data.cpu().view(-1)
      not_done = ~done
      # Write tokens only for unfinished sequences. not_done is computed
      # before done is updated, so the END token itself is recorded.
      y[:, t][not_done] = cur_output_data[not_done]
      done = done | (cur_output_data == self.END)
      cur_input = cur_output
      if bool(done.all()):
        break
    return Variable(y.type_as(x.data))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号