preprocessing_factory.py source code

Project: ActionVLAD  Author: rohitgirdhar
import tensorflow as tf

# Assumed import path: the module that provides `vgg_ucf_preprocessing` is not
# shown in this snippet, so adjust this line to match the project layout.
from preprocessing import vgg_ucf_preprocessing


def get_preprocessing(name, is_training=False):
  """Returns preprocessing_fn(image, height, width, **kwargs).

  Args:
    name: The name of the preprocessing function.
    is_training: `True` if the model is being used for training and `False`
      otherwise.

  Returns:
    preprocessing_fn: A function that preprocesses a single image (pre-batch),
      or a 4-D stack of images that must all be preprocessed the same way.
      It has the following signature:
        image = preprocessing_fn(image, output_height, output_width, ...).

  Raises:
    ValueError: If preprocessing `name` is not recognized.
  """
  preprocessing_fn_map = {
      'vgg_ucf': vgg_ucf_preprocessing,
  }

  if name not in preprocessing_fn_map:
    raise ValueError('Preprocessing name [%s] was not recognized' % name)

  def preprocessing_fn(image, output_height, output_width, **kwargs):
    with tf.variable_scope('preprocess_image'):
      if len(image.get_shape()) == 3:
        return preprocessing_fn_map[name].preprocess_image(
            image, output_height, output_width, is_training=is_training, **kwargs)
      elif len(image.get_shape()) == 4:
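        # Preprocess all images in the set identically by concatenating them
        # along the channel axis, running one preprocessing pass, and splitting
        # the result back into per-image tensors.
        # Note: tf.concat(dim, values), tf.unpack and tf.split(dim, num, value)
        # use the pre-1.0 TensorFlow argument order. In TF >= 1.0 these became
        # tf.concat(values, axis), tf.unstack and tf.split(value, num, axis).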
        nImgs = image.get_shape().as_list()[0]
        final_img_concat = preprocessing_fn_map[name].preprocess_image(
            tf.concat(2, tf.unpack(image)),
            output_height, output_width, is_training=is_training, **kwargs)
        return tf.concat(0, tf.split(3, nImgs, final_img_concat))
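      else:
        # Fail loudly instead of printing and silently returning None.
        raise ValueError(
            'Expected a 3-D or 4-D image tensor, got rank %d' %
            len(image.get_shape()))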

  return preprocessing_fn
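
The 4-D branch above stacks all images along the channel axis so that one preprocessing call applies the same random parameters to every image in the set. Below is a minimal pure-Python sketch of that idea, using nested lists in place of tensors. The helper names (`concat_channels`, `random_crop`, `split_channels`, `preprocess_frames`, `make_frame`) are hypothetical, and the random crop is only an assumed stand-in for `vgg_ucf_preprocessing.preprocess_image`, whose actual steps are not shown in this file.

```python
import random


def concat_channels(frames):
    """Join N frames of shape [H][W][C] into one image of shape [H][W][N*C]."""
    height, width = len(frames[0]), len(frames[0][0])
    return [
        [[value for frame in frames for value in frame[y][x]]
         for x in range(width)]
        for y in range(height)
    ]


def random_crop(image, out_height, out_width, rng):
    """Stand-in for a randomized preprocessing step: crop one random window."""
    top = rng.randint(0, len(image) - out_height)
    left = rng.randint(0, len(image[0]) - out_width)
    return [row[left:left + out_width] for row in image[top:top + out_height]]


def split_channels(image, num_frames):
    """Undo concat_channels: [H][W][N*C] -> N frames of shape [H][W][C]."""
    channels = len(image[0][0]) // num_frames
    return [
        [[pixel[i * channels:(i + 1) * channels] for pixel in row]
         for row in image]
        for i in range(num_frames)
    ]


def preprocess_frames(frames, out_height, out_width, rng):
    """Apply one shared random crop to every frame of a clip."""
    stacked = concat_channels(frames)
    cropped = random_crop(stacked, out_height, out_width, rng)
    return split_channels(cropped, len(frames))


def make_frame(frame_id, height=4, width=4):
    """Toy frame whose pixels record [frame_id, y, x] for easy inspection."""
    return [[[frame_id, y, x] for x in range(width)] for y in range(height)]


clip = [make_frame(0), make_frame(1), make_frame(2)]
cropped_clip = preprocess_frames(clip, 2, 2, random.Random(0))

# Every frame was cropped at the same window, so the recorded (y, x)
# coordinates agree across frames while the frame ids stay distinct.
coords = [[[pixel[1:] for pixel in row] for row in frame]
          for frame in cropped_clip]
```

Cropping each frame separately would draw a different random window per frame. Stacking first means the randomness is sampled only once for the whole clip, which is what the comment in the 4-D branch describes.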