def generate(model, xs):
    """Run ``model`` over a byte sequence and yield chunks with their tags.

    The model's state is reset via ``reset_state()``. The model is then called
    once with the whole sequence, passed as a list of single-element ``int32``
    Variables (one per input byte). Its output is consumed as one ``(y, zs)``
    pair per input byte. Bytes are accumulated until ``sigmoid(y)[0, 0] > 0.5``
    at some position. At that point the accumulated bytes are yielded together
    with the argmax index of ``softmax(z)`` (along axis 1, row 0) for every
    ``z`` in ``zs`` at that position, and accumulation restarts.

    Args:
        model: Callable with a ``reset_state()`` method. It maps a list of
            Variables to an iterable of ``(y, zs)`` pairs. Presumably ``y`` is
            a chunk-boundary logit and each ``z`` is a per-category tag logit.
            TODO confirm against the model definition.
        xs: Iterable of ints in ``range(256)`` (e.g. ``bytes``). Each int is
            passed to ``bytearray.append``. A one-shot iterator is accepted.

    Yields:
        ``(text, tag_ids)`` tuples. ``text`` is the chunk decoded as UTF-8,
        with invalid sequences replaced by U+FFFD. ``tag_ids`` is a tuple
        with one argmax index per entry of ``zs``.

    Note:
        Bytes after the last position where the threshold is exceeded are
        never yielded.
        NOTE(review): confirm that dropping this tail is intended.
    """
    # Materialize once. xs is traversed twice (to build the model inputs and
    # to zip with the outputs), so a one-shot iterator would otherwise be
    # exhausted by the first pass and nothing would be yielded.
    xs = list(xs)
    model.reset_state()
    outputs = model([Variable(np.array([x], dtype=np.int32)) for x in xs])
    buf = bytearray()
    for x, (y, zs) in zip(xs, outputs):
        buf.append(x)
        if cf.sigmoid(y).data[0, 0] > 0.5:
            yield (
                # 'replace' keeps going if a chunk splits a multi-byte char.
                buf.decode('utf-8', 'replace'),
                tuple(cf.softmax(z).data.argmax(1)[0] for z in zs),
            )
            buf = bytearray()
评论列表
文章目录