def __call__(self, xs, train=True):
    """Run the parent ``__call__`` on ``xs`` using the states kept on ``self``.

    ``self.hx`` and ``self.cx`` are passed to the parent call and then
    overwritten with the ``hy`` / ``cy`` values it returns, so the next
    invocation starts from where this one ended.  A state that is still
    ``None`` is first replaced by a zero-filled ``Variable`` of shape
    ``(self.n_layers, len(xs), self.state_size)`` with the dtype of
    ``xs[0]``.

    Args:
        xs: Indexable collection of inputs; ``len(xs)`` is used as the
            middle dimension of a freshly created state and ``xs[0].dtype``
            as its dtype.  (Presumably one array per sequence -- confirm
            against the parent class.)
        train: Forwarded unchanged to the parent ``__call__``.

    Returns:
        The third value (``ys``) returned by the parent ``__call__``.
    """
    n_seqs = len(xs)

    # Lazily create whichever carried state has not been set yet
    # (``hx`` first, then ``cx``).
    for state_name in ('hx', 'cx'):
        if getattr(self, state_name) is not None:
            continue
        array_module = self.xp
        zero_state = array_module.zeros(
            (self.n_layers, n_seqs, self.state_size), dtype=xs[0].dtype)
        setattr(self, state_name, Variable(zero_state, volatile='auto'))

    next_h, next_c, outputs = super(NStepLSTM, self).__call__(
        self.hx, self.cx, xs, train)

    # Keep the returned states for the following call.
    self.hx = next_h
    self.cx = next_c
    return outputs
评论列表
文章目录