extract_cnn_vgg16_keras.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:flask-keras-cnn-image-retrieval 作者: willard-yuan 项目源码 文件源码
def extract_feat(img_path):
    # weights: 'imagenet'
    # pooling: 'max' or 'avg'
    # input_shape: (width, height, 3), width and height should >= 48

    input_shape = (224, 224, 3)
    model = VGG16(weights = 'imagenet', input_shape = (input_shape[0], input_shape[1], input_shape[2]), pooling = 'max', include_top = False)

    img = image.load_img(img_path, target_size=(input_shape[0], input_shape[1]))
    img = image.img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = preprocess_input(img)
    feat = model.predict(img)
    norm_feat = feat[0]/LA.norm(feat[0])
    return norm_feat
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号