def get_output_for(self, input, **kwargs):
    """Index ``input`` with ``self.slice`` along dimension ``self.axis``.

    Every dimension other than ``self.axis`` is taken whole
    (``slice(None)``); dimension ``self.axis`` is indexed with
    ``self.slice``.

    :param input: symbolic tensor exposing ``get_shape().ndims``. Its rank
        must be statically known, because ``ndims`` is used in integer
        arithmetic below.
    :param kwargs: accepted for interface compatibility; unused here.
    :return: ``input`` indexed with ``self.slice`` along ``self.axis``.
    """
    ndims = input.get_shape().ndims
    axis = self.axis
    # Normalize a negative axis (e.g. -1 == last dimension) to its
    # non-negative position so the tuple/mask arithmetic below is valid.
    if axis < 0:
        axis += ndims

    # Full slices for the dimensions before and after ``axis``.
    before = (slice(None),) * axis
    after = (slice(None),) * (ndims - axis - 1)

    if isinstance(self.slice, int) and self.slice < 0:
        # Negative integer index: reverse ``input`` along ``axis`` and take
        # the mirrored non-negative position (-1 -> 0, -2 -> 1, ...) rather
        # than indexing with a negative number directly.
        # The mask must use the normalized ``axis``, not ``self.axis``: for a
        # negative ``self.axis``, ``[False] * self.axis`` is empty and the
        # mask would not have one entry per dimension.
        # NOTE(review): this is the boolean-mask signature of ``tf.reverse``
        # (TensorFlow < 1.0); TF >= 1.0 expects a list of axis indices.
        reverse_mask = [False] * axis + [True] + [False] * (ndims - axis - 1)
        return tf.reverse(input, reverse_mask)[before + (-1 - self.slice,) + after]
    return input[before + (self.slice,) + after]
评论列表
文章目录