def post_decode_once(self, output, state, train=True):
    """Charge one decoding step against each sequence's remaining length budget.

    Reads ``state['lengths']`` and subtracts the cost of the tokens just
    emitted in ``output``. Any entry whose updated budget is negative is
    replaced by the matching entry of ``self.zeros``; presumably an all-zero
    array, making this a clamp at 0 (TODO confirm). The result is written back
    to ``state['lengths']`` and ``state`` is returned.

    Args:
        output: array of token ids emitted at this step, iterated via
            ``output.tolist()``. Presumably 1-D with one id per sequence
            (TODO confirm).
        state: decoder-state dict. Only its ``'lengths'`` entry is read and
            overwritten here.
        train: when False, the boolean mask Variable is built with
            ``volatile=True``. ``volatile`` is the Chainer v1 flag that
            disables computational-graph construction.

    Returns:
        The same ``state`` dict, with ``'lengths'`` updated.
    """
    remaining = state['lengths']

    if not self.byte:
        # Token mode: every decoding step consumes exactly one unit.
        remaining -= 1
    else:
        # Byte mode: each emitted token costs the length of its string form,
        # plus one (presumably a separator byte; verify against the encoder).
        lookup = self.vocab.itos
        # One-element rows give shape (len(output), 1).
        # NOTE(review): if `remaining` is 1-D this broadcasts to 2-D;
        # confirm the shape of state['lengths'].
        step_costs = self.xp.array(
            [[len(lookup(token_id)) + 1] for token_id in output.tolist()])
        remaining -= step_costs

    # Keep budgets that are still >= 0; fall back to self.zeros elsewhere.
    still_alive = chainer.Variable(remaining.data >= 0, volatile=not train)
    state['lengths'] = F.where(still_alive, remaining, self.zeros)
    return state
# (Scraped web-page navigation labels, not code: "Comment list", "Table of contents".)