def __call__(self, s, xs):
    """Calculate all hidden states and cell states.

    All input sequences are embedded with a single ``self.embed`` call on
    their concatenation, then split back into per-sequence pieces before
    being fed to ``self.lstm``.

    Args:
        s (tuple or None): Initial states as a pair ``(hx, cx)``. ``s[0]`` is
            passed to ``self.lstm`` as the initial hidden state and ``s[1]``
            as the initial cell state. If ``None`` is specified, ``None`` is
            passed for both; presumably ``self.lstm`` then starts from zero
            vectors (TODO confirm against the link's documentation).
        xs (list of ~chainer.Variable): Non-empty list of input sequences.
            Each element ``xs[i]`` is a :class:`chainer.Variable` holding
            one sequence. An empty list raises ``IndexError``.

    Returns:
        tuple: ``((hy, cy), ys)``. ``(hy, cy)`` is the pair of hidden and
        cell states at the end of the sequences, and ``ys`` is the hidden
        state sequence at the last layer, exactly as returned by
        ``self.lstm``.
    """
    if len(xs) > 1:
        # Split points for cutting the concatenated embedding back into the
        # original sequences: cumulative lengths of all but the last one.
        sections = np.cumsum(np.array([len(x) for x in xs[:-1]], dtype=np.int32))
        xs = F.split_axis(self.embed(F.concat(xs, axis=0)), sections, axis=0)
    else:
        # Single sequence: nothing to concatenate or split.
        xs = [self.embed(xs[0])]

    # Unpack the optional initial (hidden, cell) pair once instead of
    # duplicating the self.lstm call in two branches.
    hx, cx = (None, None) if s is None else (s[0], s[1])
    hy, cy, ys = self.lstm(hx, cx, xs)
    return (hy, cy), ys
lstm_encoder.py 文件源码
python
阅读 20
收藏 0
点赞 0
评论 0
评论列表
文章目录