def create_image_model_squeezenet(images_shape, repeat_count, weights_path=None):
    """Build a Keras model that encodes an image with SqueezeNet and repeats the encoding.

    The image is run through a SqueezeNet backbone built without its top
    (``include_top=False``). It is reduced to a single feature vector with
    global max pooling, and that vector is then repeated ``repeat_count``
    times along a new axis via ``RepeatVector``.

    Args:
        images_shape: Shape of one input image, excluding the batch axis.
            It is passed straight to ``Input(shape=...)``. This is presumably
            (height, width, channels), since the backbone is built with
            ``dim_ordering='tf'``; confirm against callers.
        repeat_count: Number of times the pooled feature vector is repeated.
        weights_path: Optional path to a weights file for the SqueezeNet
            backbone, loaded with ``load_weights``. When ``None`` (the
            default, matching the previous behavior), no weights are loaded
            and the backbone keeps whatever initialization ``get_squeezenet``
            gives it.

    Returns:
        A Keras ``Model`` named ``'image_model'``. It maps the image input to
        a tensor of shape (batch, repeat_count, features), where ``features``
        is the channel count of the backbone's output feature map.
    """
    print('Using SqueezeNet')
    inputs = Input(shape=images_shape)
    # NOTE(review): 1000 is presumably the class count of the original
    # (ImageNet) head, which include_top=False drops; confirm in get_squeezenet.
    visual_model = get_squeezenet(1000, dim_ordering='tf', include_top=False)
    if weights_path is not None:
        # Replaces a previously hard-coded, commented-out call that loaded
        # 'squeezenet/model/squeezenet_weights_tf_dim_ordering_tf_kernels.h5'.
        visual_model.load_weights(weights_path)
    x = visual_model(inputs)
    # Collapse the spatial axes: feature map -> (batch, channels).
    x = GlobalMaxPooling2D()(x)
    # Tile the vector: (batch, channels) -> (batch, repeat_count, channels).
    x = RepeatVector(repeat_count)(x)
    return Model(inputs, x, name='image_model')
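# "评论列表" ("Comment list") is web-page navigation residue from the source article, not code.
# "文章目录" ("Table of contents") is web-page navigation residue from the source article, not code.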