ResNet.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:Papers2Code 作者: rainer85ah 项目源码 文件源码
def build_resnet(self, input_shape=None, num_outputs=1000, layers=None, weights_path=None):
        """Build a ResNet-style Keras model and store it on ``self.model``.

        Layout: zero-pad 3 -> 7x7 stride-2 conv (64 filters) -> batch norm -> ReLU
        -> 3x3 stride-2 max-pool, then one stage per entry of ``layers``, then 7x7
        average pooling, flatten and a softmax ``Dense`` classifier.

        Args:
            input_shape: Input shape as (nb_rows, nb_cols, nb_channels), i.e. the
                channels-last (TensorFlow) layout. If None, ``_obtain_input_shape``
                supplies the default (presumably (224, 224, 3) given
                ``default_size=224`` -- confirm for the installed Keras version).
            num_outputs: Number of units in the final softmax layer.
            layers: Sequence giving the number of residual blocks in each stage;
                its length is the number of stages (presumably e.g. [3, 4, 6, 3]
                for a ResNet-50-style network -- confirm against callers). Required.
            weights_path: Path to a weights file passed to ``Model.load_weights``.
                If None, the model keeps its freshly initialised weights.

        Returns:
            The (uncompiled) Keras ``Model``; the same object is stored on
            ``self.model``.

        Raises:
            ValueError: If ``input_shape`` is given but does not have exactly
                3 dimensions.
            TypeError: If ``layers`` is None.
        """
        # Only validate an explicit shape; None is delegated to _obtain_input_shape,
        # which fills in its default instead of crashing on len(None).
        if input_shape is not None and len(input_shape) != 3:
            raise ValueError("Input shape should be a tuple like (nb_rows, nb_cols, nb_channels)")
        if layers is None:
            raise TypeError("layers must be a sequence with the number of blocks per stage")

        # NOTE(review): later Keras releases renamed the `include_top` keyword of
        # _obtain_input_shape to `require_flatten` -- confirm against the installed Keras.
        input_shape = _obtain_input_shape(input_shape, default_size=224, min_size=197,
                                          data_format=K.image_data_format(), include_top=True)
        img_input = Input(shape=input_shape)

        # Stem ("stage 1"). axis=3 is the channel axis for channels-last input.
        x = ZeroPadding2D((3, 3))(img_input)
        x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
        x = BatchNormalization(axis=3, name='bn_conv1')(x)
        x = Activation('relu')(x)
        x = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same', name='pool1')(x)

        # Residual stages are numbered from 2 (the stem above is stage 1); the filter
        # count doubles after every stage, even one with zero blocks.
        nb_filters = 64
        for stage, nb_blocks in enumerate(layers, start=2):
            for i in range(nb_blocks):
                if i == 0:
                    # The first block of each stage uses block_with_shortcut; stage 2
                    # keeps stride 1 because the max-pool already downsampled.
                    x = block_with_shortcut(x, nb_filters, stage=stage, block='a',
                                            strides=2 if stage >= 3 else 1)
                else:
                    x = block_without_shortcut(x, nb_filters, stage=stage, block='b', index=i)
            nb_filters *= 2

        # NOTE(review): the fixed 7x7 window assumes the final feature map is 7x7
        # (e.g. a 224x224 input with four stages) -- confirm for other configurations.
        x = AveragePooling2D((7, 7), strides=(1, 1), name='avg_pool')(x)
        x = Flatten()(x)
        x = Dense(units=num_outputs, activation='softmax', name='fc1000')(x)
        self.model = Model(inputs=img_input, outputs=x, name='ResNet Model')

        if weights_path is not None:
            # Load into the model just built; the original referenced an undefined `model`.
            self.model.load_weights(weights_path)

        return self.model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号