def test_decode(self, start, eos, limit):
    """Greedily decode a token sequence for at most ``limit`` steps.

    Starting from the token id ``start``, each step feeds the previous token
    through ``output_embed`` -> ``decode1`` -> ``decode2`` -> ``output``,
    applies a softmax, and takes the highest-probability id as the next
    token. Decoding stops early as soon as that id equals ``eos``.

    Args:
        start: Token id used as the first decoder input.
        eos: Token id that terminates decoding; it is not appended to the
            result.
        limit: Maximum number of decoding steps.

    Returns:
        list[int]: The generated token ids, excluding ``start`` and ``eos``.
    """
    output = []
    token = start
    for _ in range(limit):
        # Feed a batch of one as a (1,) int32 array on every step, so the
        # first step sees the same input shape as all later steps.
        y = chainer.Variable(np.array([token], dtype=np.int32))
        decode0 = self.output_embed(y)
        decode1 = self.decode1(decode0)
        decode2 = self.decode2(decode1)
        z = self.output(decode2)
        prob = F.softmax(z)
        # Move the probabilities to host memory (no-op if already there) and
        # pick the most likely id; int() yields a plain Python int rather
        # than a NumPy scalar.
        token = int(np.argmax(cuda.to_cpu(prob.data)))
        if token == eos:
            break
        output.append(token)
    return output
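# Web-page navigation residue captured with this snippet (not code):
# "评论列表" = comment list, "文章目录" = table of contents.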