def _apply(self, X):
    """Upsample ``X`` and optionally pad/crop it to ``self.output_shape``.

    Parameters
    ----------
    X : tensor
        Input tensor. Its static rank (``X.get_shape().ndims``) is used to
        choose the upsampling axes when ``self.axes`` is ``'auto'``.

    Returns
    -------
    tensor
        ``K.upsample(X, scale=self.size, axes=..., method=self.mode)``.
        When ``self.output_shape`` is not None, every dimension whose static
        size and target size are both known numbers is zero-padded (if too
        small) or center-cropped (if too large) to the target size. The
        static shape is then set from ``self.output_shape``, with
        non-numeric entries mapped to ``None``.

    Raises
    ------
    ValueError
        If ``self.axes`` is ``'auto'`` and the input rank is not 3, 4 or 5.
    """
    axes = self.axes
    ndims = X.get_shape().ndims
    if is_string(axes) and axes.lower() == 'auto':
        # 'auto' selects every axis strictly between the first and the last:
        # rank 3 -> (1,), rank 4 -> (1, 2), rank 5 -> (1, 2, 3).
        if ndims not in (3, 4, 5):
            raise ValueError(
                "axes='auto' only supports inputs of rank 3, 4 or 5, "
                "but got rank: %s" % str(ndims))
        axes = tuple(range(1, ndims - 1))
    X = K.upsample(X, scale=self.size, axes=axes, method=self.mode)
    # ====== check output_shape ====== #
    output_shape = self.output_shape
    if output_shape is None:
        return X

    def _known(dim):
        # Only concrete numeric sizes take part in pad/crop arithmetic;
        # unknown static dimensions are reported as None by as_list().
        return dim is not None and is_number(dim)

    # do padding if necessary: split the missing size over both sides,
    # with the extra element of an odd difference placed before the data.
    # Sizes are plain Python ints so tf.pad keeps a fully static shape.
    paddings = []
    for i, o in zip(X.get_shape().as_list(), output_shape):
        if _known(i) and _known(o) and o > i:
            diff = int(o - i)
            paddings.append([(diff + 1) // 2, diff // 2])
        else:
            paddings.append([0, 0])
    if any(p != [0, 0] for p in paddings):
        X = tf.pad(X, paddings=paddings, mode='CONSTANT')
    # do slice if necessary: center-crop, dropping the extra element of an
    # odd difference from the end. The shape is re-read after padding.
    slices = []
    for i, o in zip(X.get_shape().as_list(), output_shape):
        if _known(i) and _known(o) and i > o:
            start = int(i - o) // 2
            slices.append(slice(start, start + int(o)))
        else:
            slices.append(slice(None))
    # compare by equality: a fresh slice(None) is never the same object
    if any(s != slice(None) for s in slices):
        X = X[tuple(slices)]
    K.set_shape(X, tuple([i if is_number(i) else None
                          for i in output_shape]))
    return X
评论列表
文章目录