def inception_v4(num_classes, dropout_keep_prob, weights, include_top):
    '''
    Creates the Inception-v4 network as a Keras Model.

    Args:
        num_classes: number of units in the final softmax Dense layer
            (only used when `include_top` is True).
        dropout_keep_prob: float passed unchanged as the first (`rate`)
            argument of Keras `Dropout`, i.e. the fraction of units that is
            DROPPED, not kept, before the final layer. Only used when
            `include_top` is True.
            NOTE(review): the parameter name suggests a keep probability,
            but Keras interprets the value as a drop rate. Behavior is left
            as-is so existing callers are unaffected; confirm which meaning
            callers intend before relying on the name.
        weights: 'imagenet' to download (via `get_file`) and load the
            pre-trained weight file; any other value leaves the layers as
            initialized.
        include_top: if True, append the average-pool / dropout / flatten /
            softmax classification head to the `inception_v4_base` output.

    Returns:
        A Keras `Model` named 'inception_v4'. With `include_top` its output
        is the softmax class distribution (not logits); without it, the
        output is whatever tensor `inception_v4_base` produces.
    '''
    # Query the data format once; it decides both the input layout and the
    # performance warning below.
    channels_first = K.image_data_format() == 'channels_first'

    # Input is a fixed 299 x 299 x 3 image; only the position of the channel
    # axis depends on the data format.
    input_shape = (3, 299, 299) if channels_first else (299, 299, 3)
    inputs = Input(input_shape)

    # Shared convolutional trunk (defined elsewhere in this module).
    x = inception_v4_base(inputs)

    if include_top:
        # The fixed 8 x 8 window reduces the trunk output to 1 x 1 x 1536
        # (per the original author's comment); it presumes the 299 x 299
        # input above — TODO confirm if the input size is ever made variable.
        x = AveragePooling2D((8, 8), padding='valid')(x)
        x = Dropout(dropout_keep_prob)(x)
        x = Flatten()(x)
        x = Dense(units=num_classes, activation='softmax')(x)

    model = Model(inputs, x, name='inception_v4')

    if weights == 'imagenet':
        if channels_first and K.backend() == 'tensorflow':
            warnings.warn('You are using the TensorFlow backend, yet you '
                          'are using the Theano '
                          'image data format convention '
                          '(`image_data_format="channels_first"`). '
                          'For best performance, set '
                          '`image_data_format="channels_last"` in '
                          'your Keras config '
                          'at ~/.keras/keras.json.')

        # Pick the weight file matching the presence/absence of the head.
        if include_top:
            fname = 'inception-v4_weights_tf_dim_ordering_tf_kernels.h5'
            origin = WEIGHTS_PATH
            file_md5 = '9fe79d77f793fe874470d84ca6ba4a3b'
        else:
            fname = 'inception-v4_weights_tf_dim_ordering_tf_kernels_notop.h5'
            origin = WEIGHTS_PATH_NO_TOP
            file_md5 = '9296b46b5971573064d12e4669110969'

        weights_path = get_file(
            fname,
            origin,
            cache_subdir='models',
            md5_hash=file_md5)
        # by_name=True: stored weights are matched to layers by layer name
        # rather than by layer order.
        model.load_weights(weights_path, by_name=True)

    return model
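# NOTE: scraped web-page residue (not code), kept as comments so the module imports:
# 评论列表 ("comment list")
# 文章目录 ("table of contents")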