common.py source code

python

Project: deepmodels  Author: learningsociety
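import numpy as np
import tensorflow as tf


def preprocess(self, inputs):
  """Apply the model's preprocessing function to one image or a batch.

  Args:
    inputs: raw uint8 image data, either a single image of shape
      [height, width, channels] or a batch of shape
      [batch, height, width, channels].
  Returns:
    a numpy array holding the preprocessed batch, one entry per input image.
  """
  preprocess_fn = self.get_preprocess_fn()
  assert inputs.ndim in (3, 4), "invalid image format for preprocessing"
  # Promote a single image to a batch of one so both cases share one code path.
  if inputs.ndim == 3:
    inputs = np.expand_dims(inputs, axis=0)
  # Build the ops in a fresh graph so they do not pollute the default graph.
  with tf.Graph().as_default() as cur_g:
    input_tensor = tf.convert_to_tensor(inputs, dtype=tf.uint8)
    # Split the batch into individual images; preprocess_fn works per image.
    all_inputs = tf.unstack(input_tensor)
    processed_tensors = []
    for cur_input in all_inputs:
      new_input = preprocess_fn(cur_input, self.net_params.input_img_height,
                                self.net_params.input_img_width)
      processed_tensors.append(new_input)
    # Reassemble the batch; preprocess_fn must return a fixed output shape.
    new_inputs = tf.stack(processed_tensors)
    # tf.compat.v1.Session runs this graph under TF 1.13+ and TF 2.x alike.
    with tf.compat.v1.Session(graph=cur_g) as sess:
      processed_inputs = sess.run(new_inputs)
  return processed_inputs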
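The method above builds a throwaway TensorFlow graph only to run a per-image function over a batch: normalize to 4-D, unstack, process each image, stack. The sketch below shows the same pattern in plain NumPy. The real `preprocess_fn` returned by `get_preprocess_fn()` is not shown in the source, so `center_crop_and_scale` (center crop, then map [0, 255] to [-1, 1]) and `preprocess_batch` are hypothetical stand-ins, not the deepmodels implementation.

```python
import numpy as np


def center_crop_and_scale(image, height, width):
    """Hypothetical per-image preprocess_fn (assumption, not from deepmodels):
    center-crop to (height, width) and map uint8 pixels from [0, 255] to [-1, 1]."""
    h, w = image.shape[:2]
    if h < height or w < width:
        raise ValueError("image is smaller than the target size")
    top = (h - height) // 2
    left = (w - width) // 2
    crop = image[top:top + height, left:left + width]
    return crop.astype(np.float32) / 127.5 - 1.0


def preprocess_batch(inputs, preprocess_fn, height, width):
    """Apply preprocess_fn to every image of a single image or a batch."""
    if inputs.ndim not in (3, 4):
        raise ValueError("invalid image format for preprocessing")
    # Treat a single HxWxC image as a batch of one.
    if inputs.ndim == 3:
        inputs = np.expand_dims(inputs, axis=0)
    # Iterating over axis 0 plays the role of tf.unstack; np.stack rebuilds the batch.
    processed = [preprocess_fn(img, height, width) for img in inputs]
    return np.stack(processed)


if __name__ == "__main__":
    single = np.full((8, 10, 3), 255, dtype=np.uint8)
    out = preprocess_batch(single, center_crop_and_scale, 4, 4)
    print(out.shape)  # (1, 4, 4, 3)
```

As in the TensorFlow version, the final stack only works if the per-image function returns the same shape for every image, which is why a fixed-size crop or resize is the usual first step.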