def pre(self, x):
    """Pre-compute the non-recurrent (convolutional) part of the QRNN gates.

    Requires: import numpy as np; import chainer.functions as F;
    from chainer import Variable.
    """
    dims = len(x.shape) - 1
    if self.kernel_size == 1:
        # A width-1 convolution is just a pointwise linear map.
        ret = self.W(x)
    elif self.kernel_size == 2:
        if dims == 2:
            # Full-sequence input (batch, time, features): shift x right
            # by one timestep, zero-padding the first step.
            xprev = Variable(
                self.xp.zeros((self.batch_size, 1, self.in_size),
                              dtype=np.float32), volatile='AUTO')
            xtminus1 = F.concat((xprev, x[:, :-1, :]), axis=1)
        else:
            # Single-timestep input: reuse the input stored from the
            # previous step.
            xtminus1 = self.x
        ret = self.W(x) + self.V(xtminus1)
    else:
        # Wider kernels: run the 1-D convolution over the time axis and
        # truncate the padded output back to the input sequence length.
        ret = F.swapaxes(self.conv(
            F.swapaxes(x, 1, 2))[:, :, :x.shape[1]], 1, 2)
    if not self.attention:
        return ret
    # With attention, add a transformed copy of the encoder's final
    # state to the gate pre-activations at every timestep.
    if dims == 1:
        enc = self.encoding[:, -1, :]
    else:
        enc = self.encoding[:, -1:, :]
    return sum(F.broadcast(self.U(enc), ret))
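
For context, here is a self-contained NumPy sketch of what the kernel_size == 2 branch computes: the gate pre-activations W x_t + V x_{t-1} for every timestep at once, so the recurrent pooling that follows only has element-wise work left to do. All names below (W, V, batch, time, in_size, out_size) are illustrative, not part of the class above.

import numpy as np

batch, time, in_size, out_size = 4, 6, 8, 10
rng = np.random.default_rng(0)
x = rng.standard_normal((batch, time, in_size)).astype(np.float32)
W = rng.standard_normal((in_size, out_size)).astype(np.float32)
V = rng.standard_normal((in_size, out_size)).astype(np.float32)

# Shift x right one step along the time axis, zero-padding the front;
# this mirrors F.concat((xprev, x[:, :-1, :]), axis=1) above.
x_prev = np.concatenate([np.zeros((batch, 1, in_size), np.float32),
                         x[:, :-1, :]], axis=1)

gates = x @ W + x_prev @ V  # (batch, time, out_size), no recurrence needed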