def set_state(chain, state):
    """Restore the recurrent state of the direct children of ``chain``.

    Children of ``chain`` are paired positionally with the entries of
    ``state`` and handled by type:

    * ``chainer.links.LSTM``: the entry is unpacked as ``(c, h)``. When
      ``c`` is not ``None`` the pair is passed to ``LSTM.set_state``;
      otherwise the link is reset so it does not keep a stale state.
    * ``Recurrent``: the entry is handed to the link's own ``set_state``.
    * ``chainer.Chain`` / ``chainer.ChainList``: this function recurses
      with the entry as the nested state.
    * Any other link: the entry must be ``None``.

    The pairing uses ``zip``, so if ``chain`` and ``state`` have different
    lengths the surplus children or entries are silently ignored.

    Args:
        chain: A ``chainer.Chain`` or ``chainer.ChainList`` whose children
            should receive the state.
        state: Iterable with one entry per child, in ``chain.children()``
            order. Presumably the structure produced by a matching
            state-reading helper -- TODO confirm against the caller.

    Raises:
        AssertionError: If ``chain`` is not a ``Chain``/``ChainList``, or
            if a child that is not handled above is paired with a
            non-``None`` entry.
    """
    assert isinstance(chain, (chainer.Chain, chainer.ChainList)), \
        'chain must be a chainer.Chain or chainer.ChainList'
    for link, link_state in zip(chain.children(), state):
        # NOTE(review): the LSTM test presumably has to come before the
        # generic Chain test because LSTM is itself a Chain -- confirm.
        if isinstance(link, chainer.links.LSTM):
            c, h = link_state
            if c is not None:
                link.set_state(c, h)
            else:
                # LSTM.set_state doesn't accept None state, so restore an
                # empty snapshot by resetting instead of leaving whatever
                # state the link currently holds.
                link.reset_state()
        elif isinstance(link, Recurrent):
            link.set_state(link_state)
        elif isinstance(link, (chainer.Chain, chainer.ChainList)):
            set_state(link, link_state)
        else:
            assert link_state is None, \
                'non-recurrent link must be paired with a None state'
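# Web-page residue, not code: "评论列表" (comment list)
# Web-page residue, not code: "文章目录" (article table of contents)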