def __call__(self, X, skip_mask=None):
    """Apply the convolution, trim right padding, split, and pool.

    Args:
        X: Input sequence batch passed straight to ``self.W``.
            NOTE(review): presumably shaped (batch, channels, time) with
            time on the last axis, since trimming is done on ``[..., :-pad]``
            and the split is on axis 1 — confirm against the caller.
        skip_mask: Forwarded unchanged to ``self.pool``; its meaning is
            defined there (not visible here).

    Returns:
        Whatever ``self.pool`` returns for the list of ``self.num_split``
        pieces produced by splitting the trimmed convolution output.
    """
    # Remove right paddings so each output step only sees the current and
    # previous inputs, e.g.
    #   kernel_size = 3
    #   pad = 2
    #   input sequence with paddings:
    #   [0, 0, x1, x2, x3, 0, 0]
    #    |< t1  >|
    #       |< t2  >|
    #          |< t3  >|
    # The last `pad` outputs (windows reaching into the right padding) are
    # dropped. This assumes self.W pads `kernel_size - 1` on both ends of
    # the last axis, as the diagram shows — TODO confirm against W's config.
    pad = self._kernel_size - 1
    WX = self.W(X)
    # Guard pad == 0 (kernel_size == 1): `WX[..., :-0]` is `WX[..., :0]`,
    # which would silently return an empty time axis.
    if pad > 0:
        WX = WX[..., :-pad]
    # Split channels (axis 1) into `num_split` equal parts for the pooling
    # step. NOTE(review): presumably the gate/candidate tensors — confirm.
    return self.pool(
        functions.split_axis(WX, self.num_split, axis=1),
        skip_mask=skip_mask,
    )
# (web page residue) Comment list
# (web page residue) Table of contents