def forward_one_step(self, X, skip_mask=None):
    """Compute one step of the convolution + pooling pipeline.

    Applies ``self.W`` to ``X`` and keeps a single position along axis 2:
    index ``-self._kernel_size``, which equals ``-(kernel_size - 1) - 1``.
    A singleton axis is re-inserted in its place. The result is split into
    ``self.num_split`` chunks along axis 1, and those chunks are handed to
    ``self.pool``.

    NOTE(review): this presumably assumes ``self.W`` is a convolution padded
    by ``kernel_size - 1``, so that this index selects the output for the
    newest time step. Confirm against the layer's constructor.

    Args:
        X: Input accepted by ``self.W``. Assumed to be laid out as
            (batch, channels, time); TODO confirm.
        skip_mask: Forwarded unchanged to ``self.pool`` as its
            ``skip_mask`` keyword.

    Returns:
        Whatever ``self.pool`` returns for the split chunks.
    """
    # Negative offset of the single position kept along axis 2.
    step_index = -self._kernel_size
    conv_out = self.W(X)
    # ``None`` keeps a length-1 axis where the indexed axis was.
    step_out = conv_out[:, :, step_index, None]
    chunks = functions.split_axis(step_out, self.num_split, axis=1)
    return self.pool(chunks, skip_mask=skip_mask)