def get_state(chain):
    """Collect a state snapshot for each direct child link of *chain*.

    Returns a list with one entry per child yielded by ``chain.children()``,
    in that order:

    * ``chainer.links.LSTM`` child -> the tuple ``(child.c, child.h)``
    * ``Recurrent`` child -> whatever ``child.get_state()`` returns
    * other ``chainer.Chain`` / ``chainer.ChainList`` child -> the nested
      list produced by recursing into that child
    * anything else -> ``None`` (kept as a placeholder so positions line up
      with the children)

    Raises:
        AssertionError: if *chain* is not a ``chainer.Chain`` or
            ``chainer.ChainList`` (only when assertions are enabled).
    """
    assert isinstance(chain, (chainer.Chain, chainer.ChainList))

    def _snapshot(link):
        # Checks run in priority order: the LSTM test precedes the generic
        # Chain/ChainList test, so a link matching both is treated as an
        # LSTM rather than recursed into.
        if isinstance(link, chainer.links.LSTM):
            return (link.c, link.h)
        if isinstance(link, Recurrent):
            return link.get_state()
        if isinstance(link, (chainer.Chain, chainer.ChainList)):
            return get_state(link)
        return None

    return [_snapshot(link) for link in chain.children()]
评论列表
文章目录