cifar10_cnn_mgpu_tfqueue.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:keras_experiments 作者: avolkov1 项目源码 文件源码
def make_model(train_input, num_classes, weights_file=None):
    '''Assemble the small convolutional classifier used by this example.

    Layout: two convolutional stages (32 filters, then 64 filters). Each stage
    is a 'same'-padded 3x3 conv + ReLU, an unpadded 3x3 conv + ReLU, 2x2
    max-pooling and 25% dropout. These are followed by flatten, a 512-unit
    dense layer + ReLU, 50% dropout and a ``num_classes``-way softmax.

    :param train_input: Either a TensorFlow tensor, which is wired in directly
        as the model's input tensor, or a tuple/list giving the input shape.
        Accepting two types in one parameter is unusual, but it keeps the
        call sites simple here.
    :type train_input: tf.Tensor or tuple/list
    :param num_classes: Number of units in the final softmax layer.
    :type num_classes: int
    :param weights_file: Optional path to previously saved weights. They are
        loaded only when a path is given and it exists on disk; otherwise the
        model keeps its freshly initialized weights.
    :type weights_file: str or None
    :returns: The assembled Sequential model (not compiled here).
    '''
    net = Sequential()

    # A ready-made tensor becomes the input directly; anything else is
    # interpreted as the per-sample input shape.
    if isinstance(train_input, tf.Tensor):
        net.add(KL.InputLayer(input_tensor=train_input))
    else:
        net.add(KL.InputLayer(input_shape=train_input))

    # Both feature-extraction stages share one structure and differ only
    # in their filter count.
    for n_filters in (32, 64):
        net.add(KL.Conv2D(n_filters, (3, 3), padding='same'))
        net.add(KL.Activation('relu'))
        net.add(KL.Conv2D(n_filters, (3, 3)))
        net.add(KL.Activation('relu'))
        net.add(KL.MaxPooling2D(pool_size=(2, 2)))
        net.add(KL.Dropout(0.25))

    classifier_head = (
        KL.Flatten(),
        KL.Dense(512),
        KL.Activation('relu'),
        KL.Dropout(0.5),
        KL.Dense(num_classes),
        KL.Activation('softmax'),
    )
    for layer in classifier_head:
        net.add(layer)

    # Best-effort restore: a missing weights file is silently ignored.
    have_weights = weights_file is not None and os.path.exists(weights_file)
    if have_weights:
        net.load_weights(weights_file)

    return net
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号