def update_state(self, step, src_seq, enc_outputs, un_dones):
    """Tile decoder inputs so that ``un_dones`` rows can be decoded together.

    Args:
        self: Not used by this function.
        step: Number of positions decoded so far. Position indices
            ``1..step`` are produced (presumably index 0 is reserved for
            padding -- TODO confirm against the position embedding).
        src_seq: Source tensor, tiled ``un_dones`` times along dim 0.
            ``repeat(un_dones, 1)`` assumes it is 2-D.
        enc_outputs: Iterable of encoder output tensors, each tiled
            ``un_dones`` times along dim 0. ``repeat(un_dones, 1, 1)``
            assumes each one is 3-D.
        un_dones: Number of copies to make. Presumably this is the count of
            unfinished beam hypotheses -- verify against the caller.

    Returns:
        A tuple ``(input_pos, src_seq_beam, enc_outputs_beam)``:

        - ``input_pos`` is an int64 tensor of shape ``(un_dones, step)`` on
          ``src_seq``'s device, with every row equal to ``1..step``.
        - ``src_seq_beam`` is the tiled copy of ``src_seq``.
        - ``enc_outputs_beam`` is a list of tiled copies of ``enc_outputs``.

        None of the returned tensors carry autograd history.
    """
    # Create the positions directly as int64 on the source's device, so
    # callers do not have to move them. The legacy
    # ``Variable(..., volatile=True)`` form is gone: ``volatile`` was removed
    # in PyTorch 0.4 and now only emits a warning without any effect.
    input_pos = torch.arange(1, step + 1, dtype=torch.long, device=src_seq.device)
    input_pos = input_pos.unsqueeze(0).repeat(un_dones, 1)

    # detach() is the supported replacement for the ``Variable(x.data)``
    # round trip. The tiled copies share no autograd history with the inputs.
    src_seq_beam = src_seq.detach().repeat(un_dones, 1)
    enc_outputs_beam = [
        enc_output.detach().repeat(un_dones, 1, 1) for enc_output in enc_outputs
    ]
    return input_pos, src_seq_beam, enc_outputs_beam
评论列表
文章目录