def __call__(self, x):
    # Assumes chainer.functions is imported as F at module level.
    if not hasattr(self, 'encoding') or self.encoding is None:
        self.batch_size = x.shape[0]
        self.init()
    dims = len(x.shape) - 1
    # A single projection yields the forget, update and output gates.
    f, z, o = F.split_axis(self.pre(x), 3, axis=dims)
    f = F.sigmoid(f)
    z = (1 - f) * F.tanh(z)  # pre-scale the update by (1 - f)
    o = F.sigmoid(o)
    if dims == 2:
        # Full sequence: run the recurrence c_t = f_t * c_{t-1} + z_t.
        self.c = strnn(f, z, self.c[:self.batch_size])
    else:
        # Single timestep: the same recurrence, applied once.
        self.c = f * self.c + z
    if self.attention:
        # Mix an attention context over the encoder states into the output.
        context = attention_sum(self.encoding, self.c)
        self.h = o * self.o(F.concat((self.c, context), axis=dims))
    else:
        self.h = self.c * o
    self.x = x
    return self.h
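
The only non-standard primitive above is strnn(), which evaluates the recurrence c_t = f_t * c_{t-1} + z_t over the whole time axis; the single-step branch self.c = f * self.c + z is the same update applied once. Below is a minimal NumPy sketch of what such a helper computes (strnn_reference is a hypothetical name, and the shape conventions are assumptions inferred from the code: f and z as (batch, seq_len, hidden), the initial state c0 as (batch, hidden)); an actual strnn implementation would typically be a fused GPU kernel rather than a Python loop.

import numpy as np

def strnn_reference(f, z, c0):
    # Sequential pooling: c_t = f_t * c_{t-1} + z_t, returning every c_t.
    # f, z: (batch, seq_len, hidden); c0: (batch, hidden).
    c = np.empty_like(f)
    prev = c0
    for t in range(f.shape[1]):
        prev = f[:, t] * prev + z[:, t]
        c[:, t] = prev
    return c

Since the gates f, z, o are all computed from the input in one shot before this loop runs, this recurrence is the only sequential computation in the layer.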