def reset_state(chain):
    """Recursively reset the recurrent state held by the children of *chain*.

    Walks the direct children of ``chain``. Any child that is a
    ``chainer.links.LSTM`` or a ``Recurrent`` has its ``reset_state()``
    method called. Any other child that is a ``chainer.Chain`` or
    ``chainer.ChainList`` is descended into recursively. Children of any
    other type are ignored. ``chain`` itself is never reset, even if it is
    recurrent; only its (transitive) children are visited.

    Args:
        chain: The ``chainer.Chain`` or ``chainer.ChainList`` whose
            descendants should be reset.

    Raises:
        TypeError: If ``chain`` is not a ``chainer.Chain`` or
            ``chainer.ChainList``.
    """
    # Use an explicit raise rather than ``assert`` so the check stays active
    # under ``python -O``, where assert statements are stripped.
    if not isinstance(chain, (chainer.Chain, chainer.ChainList)):
        raise TypeError(
            'reset_state expects a chainer.Chain or chainer.ChainList, '
            'got {}'.format(type(chain).__name__))
    for child in chain.children():
        # Recurrent links are checked before generic containers. A recurrent
        # link that is also a Chain therefore gets its own reset_state()
        # called instead of being descended into.
        # NOTE(review): Recurrent is defined elsewhere; presumably an
        # interface guaranteeing a reset_state() method. Confirm.
        if isinstance(child, (chainer.links.LSTM, Recurrent)):
            child.reset_state()
        elif isinstance(child, (chainer.Chain, chainer.ChainList)):
            reset_state(child)
评论列表
文章目录