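# Module-level imports used by this method.
import numpy as np
import chainer


def update(self, s, i):
    """Update the decoder state with one input label.

    Args:
        s (tuple or None): Current (hidden, cell) states. If ``None`` is
            specified, zero vectors are used.
        i (int): Input label.

    Returns:
        tuple: ``(hy, cy, dy)``, the updated hidden and cell states
        (~chainer.Variable) and the per-step outputs returned by
        ``self.lstm``.
    """
    # Pick NumPy or CuPy from the link's own device; this also works
    # when s is None, so there is no state array to inspect.
    xp = self.xp
    # Wrap the label as a length-1 int32 sequence and embed it.
    v = chainer.Variable(xp.array([i], dtype=np.int32))
    x = self.embed(v)
    if s is not None:
        hy, cy, dy = self.lstm(s[0], s[1], [x])
    else:
        # Passing None for hx and cx makes the LSTM start from zero states.
        hy, cy, dy = self.lstm(None, None, [x])
    return hy, cy, dy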
# Source file: lstm_decoder.py
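To make the effect of one `update` call concrete, here is a minimal NumPy sketch, assuming a standard single-layer LSTM cell, batch size 1, and CPU only. Everything in it is hypothetical and for illustration: the `ToyLSTMDecoder` class, its randomly initialised weights, the gate order (input, forget, candidate, output), the `scores` output projection, and the `greedy_decode` loop. It mirrors the method's contract (a `None` state means zero vectors; the call returns new hidden state, cell state, and output) but does not reproduce Chainer's internal parameter layout. For a single-layer, single-step LSTM, the output is simply the new hidden state.

```python
import numpy as np


class ToyLSTMDecoder:
    """Hypothetical NumPy stand-in for an embedding + single-layer LSTM decoder."""

    def __init__(self, n_vocab, n_embed, n_hidden, seed=0):
        rng = np.random.default_rng(seed)
        self.n_hidden = n_hidden
        # Embedding table: one row per output label.
        self.embed_w = rng.normal(0.0, 0.1, (n_vocab, n_embed))
        # Stacked weights for the 4 gates (input, forget, candidate, output).
        self.w_x = rng.normal(0.0, 0.1, (n_embed, 4 * n_hidden))
        self.w_h = rng.normal(0.0, 0.1, (n_hidden, 4 * n_hidden))
        self.b = np.zeros(4 * n_hidden)
        # Hypothetical projection from hidden state to label scores.
        self.w_out = rng.normal(0.0, 0.1, (n_hidden, n_vocab))

    @staticmethod
    def _sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    def update(self, s, i):
        """One decoder step: feed label ``i``, return (h, c, y)."""
        if s is None:
            # No state given: start from zero vectors.
            h = np.zeros(self.n_hidden)
            c = np.zeros(self.n_hidden)
        else:
            h, c = s[0], s[1]
        x = self.embed_w[i]                      # embedding lookup for label i
        z = x @ self.w_x + h @ self.w_h + self.b  # all gate pre-activations
        n = self.n_hidden
        i_g = self._sigmoid(z[:n])               # input gate
        f_g = self._sigmoid(z[n:2 * n])          # forget gate
        g = np.tanh(z[2 * n:3 * n])              # candidate cell values
        o_g = self._sigmoid(z[3 * n:])           # output gate
        c_new = f_g * c + i_g * g
        h_new = o_g * np.tanh(c_new)
        # Single layer, single step: the output is the new hidden state.
        return h_new, c_new, h_new

    def scores(self, y):
        """Unnormalised scores over the vocabulary for output y."""
        return y @ self.w_out


def greedy_decode(decoder, sos, eos, max_len=10):
    """Hypothetical greedy loop: feed back the argmax label until ``eos``."""
    s = None
    label = sos
    out = []
    for _ in range(max_len):
        h, c, y = decoder.update(s, label)
        s = (h, c)
        label = int(np.argmax(decoder.scores(y)))
        if label == eos:
            break
        out.append(label)
    return out


dec = ToyLSTMDecoder(n_vocab=5, n_embed=3, n_hidden=4)
h, c, y = dec.update(None, 1)
print(h.shape, c.shape)  # (4,) (4,)
print(greedy_decode(dec, sos=0, eos=4))
```

The loop shows why `update` takes and returns the state explicitly: the caller threads `(h, c)` from one step into the next and chooses the next input label itself, here by greedy argmax.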