tf-keras-skeleton.py source code

python

Project: LIE  Author: EmbraceLife
def _obtain_input_shape(input_shape, default_size, min_size, data_format,
                        include_top):
    """Internal utility to compute/validate an ImageNet model's input shape.

    Arguments:
        input_shape: either None (will return the default network input shape),
            or a user-provided shape to be validated.
        default_size: default input width/height for the model.
        min_size: minimum input width/height accepted by the model.
        data_format: image data format to use ('channels_first' or
            'channels_last').
        include_top: whether the model is expected to
            be linked to a classifier via a Flatten layer.

    Returns:
        An integer shape tuple (may include None entries).

    Raises:
        ValueError: in case of invalid argument values.
    """
    # Default RGB input shape, with the channel axis placed per data_format.
    if data_format == 'channels_first':
        default_shape = (3, default_size, default_size)
    else:
        default_shape = (default_size, default_size, 3)
    if include_top:
        # The Flatten-based classifier head only accepts the default shape.
        if input_shape is not None:
            if input_shape != default_shape:
                raise ValueError('When setting `include_top=True`, '
                                 '`input_shape` should be ' + str(default_shape) + '.')
        input_shape = default_shape
    else:
        if data_format == 'channels_first':
            if input_shape is not None:
                if len(input_shape) != 3:
                    raise ValueError('`input_shape` must be a tuple of three integers.')
                # Layout is (channels, height, width).
                if input_shape[0] != 3:
                    raise ValueError('The input must have 3 channels; got '
                                     '`input_shape=' + str(input_shape) + '`')
                if ((input_shape[1] is not None and input_shape[1] < min_size) or
                        (input_shape[2] is not None and input_shape[2] < min_size)):
                    raise ValueError('Input size must be at least ' + str(min_size) + 'x'
                                     + str(min_size) + ', got '
                                     '`input_shape=' + str(input_shape) + '`')
            else:
                # No shape given: leave height and width unspecified.
                input_shape = (3, None, None)
        else:
            if input_shape is not None:
                if len(input_shape) != 3:
                    raise ValueError('`input_shape` must be a tuple of three integers.')
                # Layout is (height, width, channels).
                if input_shape[-1] != 3:
                    raise ValueError('The input must have 3 channels; got '
                                     '`input_shape=' + str(input_shape) + '`')
                if ((input_shape[0] is not None and input_shape[0] < min_size) or
                        (input_shape[1] is not None and input_shape[1] < min_size)):
                    raise ValueError('Input size must be at least ' + str(min_size) + 'x'
                                     + str(min_size) + ', got '
                                     '`input_shape=' + str(input_shape) + '`')
            else:
                # No shape given: leave height and width unspecified.
                input_shape = (None, None, 3)
    return input_shape
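The channels_first and channels_last branches of _obtain_input_shape run the same three checks (rank, channel count, minimum spatial size) and differ only in which axis holds the channels. The following is a minimal sketch, not code from the LIE project or from Keras, that computes the channel axis and spatial axes once and then runs the checks a single time. The function name obtain_input_shape_compact and the Shape alias are hypothetical. It is written to keep the original's check order and return values, so the calls at the bottom indicate what the original function returns for the same arguments.

```python
from typing import Optional, Tuple

# Hypothetical alias: a 3-element shape whose entries may be None.
Shape = Tuple[Optional[int], ...]


def obtain_input_shape_compact(input_shape: Optional[Shape], default_size: int,
                               min_size: int, data_format: str,
                               include_top: bool) -> Shape:
    """Hypothetical compact variant of _obtain_input_shape (not a Keras API)."""
    channels_first = data_format == 'channels_first'
    # Position of the channel axis and of the two spatial (height, width) axes.
    channel_axis = 0 if channels_first else -1
    spatial_axes = (1, 2) if channels_first else (0, 1)

    if channels_first:
        default_shape = (3, default_size, default_size)
    else:
        default_shape = (default_size, default_size, 3)

    if include_top:
        # Only the exact default shape is accepted; None falls back to it.
        if input_shape is not None and input_shape != default_shape:
            raise ValueError('When setting `include_top=True`, '
                             '`input_shape` should be ' + str(default_shape) + '.')
        return default_shape

    if input_shape is None:
        # Height and width are left unspecified (None).
        return (3, None, None) if channels_first else (None, None, 3)

    # Same three checks as the original, indexed by the axes computed above.
    if len(input_shape) != 3:
        raise ValueError('`input_shape` must be a tuple of three integers.')
    if input_shape[channel_axis] != 3:
        raise ValueError('The input must have 3 channels; got '
                         '`input_shape=' + str(input_shape) + '`')
    for axis in spatial_axes:
        size = input_shape[axis]
        if size is not None and size < min_size:
            raise ValueError('Input size must be at least ' + str(min_size) + 'x'
                             + str(min_size) + ', got '
                             '`input_shape=' + str(input_shape) + '`')
    return input_shape


if __name__ == '__main__':
    print(obtain_input_shape_compact(None, 224, 48, 'channels_last', True))
    # prints (224, 224, 3)
    print(obtain_input_shape_compact(None, 224, 48, 'channels_first', False))
    # prints (3, None, None)
    try:
        obtain_input_shape_compact((32, 32, 3), 224, 48, 'channels_last', False)
    except ValueError as err:
        print(err)  # 32 < 48, so the minimum-size check fails
```

Factoring out the axis positions removes the duplicated branch, at the cost of making the (channels, height, width) versus (height, width, channels) layout implicit in two index tuples rather than spelled out in each branch.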