sampling.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:odin 作者: imito 项目源码 文件源码
def _apply(self, X):
  """Upsample ``X`` and optionally pad/crop it to ``self.output_shape``.

  Reads ``self.axes``, ``self.size``, ``self.mode`` and
  ``self.output_shape``.

  * If ``self.axes`` is the string ``'auto'`` (case-insensitive), it is
    replaced by ``(1,)``, ``(1, 2)`` or ``(1, 2, 3)`` for inputs of rank
    3, 4 or 5 respectively.
  * ``X`` is upsampled with ``K.upsample(X, scale=self.size, axes=axes,
    method=self.mode)``.
  * If ``self.output_shape`` is not None, every dimension whose static
    size and target size are both numbers is matched to the target:
    smaller dimensions are zero-padded (``tf.pad`` in ``'CONSTANT'``
    mode; the odd extra element goes on the leading side), and larger
    dimensions are centre-cropped (the odd extra element is removed from
    the trailing side). Dimensions that are unknown on either side are
    left untouched. Finally the static shape is set to ``output_shape``,
    with non-numeric entries replaced by None.

  Args:
    X: input tensor with a statically known rank.

  Returns:
    The upsampled (and possibly padded/cropped) tensor.
  """
  axes = self.axes
  ndims = X.get_shape().ndims
  if is_string(axes) and axes.lower() == 'auto':
    if ndims == 3:
      axes = (1,)
    elif ndims == 4:
      axes = (1, 2)
    elif ndims == 5:
      axes = (1, 2, 3)
    # NOTE(review): for any other rank ``axes`` is still the string 'auto'
    # when it reaches K.upsample -- confirm K.upsample accepts that.
  X = K.upsample(X, scale=self.size, axes=axes, method=self.mode)
  # ====== check output_shape ====== #
  output_shape = self.output_shape
  if output_shape is not None:
    # NOTE(review): assumes len(output_shape) equals the rank of X; zip()
    # silently truncates otherwise and tf.pad needs one pair per dim.
    # --- pad dimensions that are smaller than the target ---
    # Static sizes are plain Python ints, so the amounts are computed with
    # exact integer arithmetic instead of building tf.ceil/tf.floor ops.
    paddings = []
    for i, o in zip(X.get_shape().as_list(), output_shape):
      if is_number(i) and is_number(o) and i < o:
        diff = o - i
        paddings.append([(diff + 1) // 2, diff // 2])  # [ceil, floor]
      else:
        paddings.append([0, 0])
    if any(p != [0, 0] for p in paddings):
      X = tf.pad(X, paddings=paddings, mode='CONSTANT')
    # --- crop dimensions that are larger than the target ---
    # Unknown sizes (e.g. a None batch dimension) must be skipped here too:
    # ``None > int`` raises TypeError on Python 3.
    slices = []
    for i, o in zip(X.get_shape().as_list(), output_shape):
      if is_number(i) and is_number(o) and i > o:
        diff = i - o
        # start = floor(diff / 2), stop = -ceil(diff / 2); diff >= 1 so the
        # stop index is always negative and never collapses to 0.
        slices.append(slice(diff // 2, -((diff + 1) // 2)))
      else:
        slices.append(slice(None))
    # Compare by value: ``s is not slice(None)`` is always True because
    # every ``slice(None)`` call creates a new object.
    if any(s != slice(None) for s in slices):
      X = X[tuple(slices)]
    K.set_shape(X, tuple([i if is_number(i) else None
                          for i in output_shape]))
  return X
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号