def make_model(train_input, num_classes, weights_file=None):
    """
    Build a small convolutional classifier as a Sequential model.

    Layout: two convolutional stages with 32 and then 64 filters. Each stage
    is Conv2D(3x3, padding='same') -> ReLU -> Conv2D(3x3, default padding) ->
    ReLU -> MaxPooling2D(2x2) -> Dropout(0.25). The stages are followed by
    Flatten -> Dense(512) -> ReLU -> Dropout(0.5) -> Dense(num_classes) ->
    softmax.

    :param train_input: Either a tensorflow Tensor, which becomes the model's
        input tensor, or a tuple/list that is used as the input shape.
        Accepting two different types is unusual style, but is acceptable
        here.
    :type train_input: tf.Tensor or tuple/list
    :param num_classes: Number of output units in the final softmax layer.
    :type num_classes: int
    :param weights_file: Optional path to previously saved weights. They are
        loaded only when a path is given AND it exists on disk. Otherwise the
        freshly initialized model is returned, with no warning.
    :type weights_file: str or None
    :returns: The model. It is not compiled here; compiling is left to the
        caller.
    """
    model = Sequential()

    # Pick which InputLayer keyword to use based on the argument's type.
    # NOTE(review): an earlier version passed ``inshape[1:]`` as the shape,
    # which suggests the tuple/list form excludes the batch dimension --
    # confirm against callers.
    if isinstance(train_input, tf.Tensor):
        input_kwargs = {'input_tensor': train_input}
    else:
        input_kwargs = {'input_shape': train_input}
    model.add(KL.InputLayer(**input_kwargs))

    # Two identical conv stages; only the filter count differs.
    for n_filters in (32, 64):
        model.add(KL.Conv2D(n_filters, (3, 3), padding='same'))
        model.add(KL.Activation('relu'))
        model.add(KL.Conv2D(n_filters, (3, 3)))
        model.add(KL.Activation('relu'))
        model.add(KL.MaxPooling2D(pool_size=(2, 2)))
        model.add(KL.Dropout(0.25))

    # Fully connected classifier head.
    model.add(KL.Flatten())
    model.add(KL.Dense(512))
    model.add(KL.Activation('relu'))
    model.add(KL.Dropout(0.5))
    model.add(KL.Dense(num_classes))
    model.add(KL.Activation('softmax'))

    # Best-effort restore: a missing weights file is silently skipped.
    has_weights = weights_file is not None and os.path.exists(weights_file)
    if has_weights:
        model.load_weights(weights_file)
    return model
cifar10_cnn_mgpu_tfqueue.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录