utils.py source code


Project: neural-vqa-tensorflow    Author: paarthneekhara
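import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.GraphDef, tf.placeholder, tf.Session)


def extract_fc7_features(image_path, model_path):
    # Read the frozen VGG16 GraphDef; serialized protobuf is binary, so open in 'rb' mode.
    with open(model_path, 'rb') as vgg_file:
        vgg16raw = vgg_file.read()

    # Import into a fresh graph so repeated calls do not stack extra copies
    # ("import_1/", "import_2/", ...) into the default graph, which would leave
    # "import/Relu_1:0" wired to a placeholder that is never fed.
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(vgg16raw)
        images = tf.placeholder("float32", [None, 224, 224, 3])
        # Route the model's "images" input through our placeholder.
        tf.import_graph_def(graph_def, input_map={"images": images})
        # The tensor this project uses as the fc7 (post-ReLU) activation.
        fc7_tensor = graph.get_tensor_by_name("import/Relu_1:0")

    # load_image_array is defined elsewhere in utils.py and is expected
    # to return a single 224x224x3 image.
    image_array = load_image_array(image_path)
    # Add a batch dimension: shape (1, 224, 224, 3).
    image_feed = np.expand_dims(image_array, axis=0).astype(np.float32)

    with tf.Session(graph=graph) as sess:
        fc7_features = sess.run(fc7_tensor, feed_dict={images: image_feed})
    return fc7_features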
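The function depends on load_image_array, which lives elsewhere in utils.py and is not shown here. As a hedged illustration only, the sketch below assumes that helper must turn a decoded image into a 224x224x3 float32 array. The names load_image_array_sketch and make_batch, the nearest-neighbour resize, and the [0, 1] scaling are all assumptions for illustration, not the project's actual code. It takes an already-decoded NumPy array so that it needs no image-decoding library.

```python
import numpy as np

TARGET_SIZE = 224  # matches the [None, 224, 224, 3] placeholder


def load_image_array_sketch(img):
    """Hypothetical stand-in for utils.load_image_array (assumed behaviour).

    Converts an H x W (grayscale), H x W x 3 (RGB) or H x W x 4 (RGBA) array
    to RGB, resizes it to 224x224 with nearest-neighbour sampling, and scales
    pixel values to [0, 1] as float32 (the scaling is an assumption).
    """
    img = np.asarray(img)
    if img.ndim == 2:  # grayscale: replicate the single channel three times
        img = np.stack([img, img, img], axis=-1)
    elif img.shape[-1] == 4:  # RGBA: drop the alpha channel
        img = img[..., :3]
    h, w = img.shape[:2]
    # Nearest-neighbour resize: choose a source row/column for each output pixel.
    rows = np.arange(TARGET_SIZE) * h // TARGET_SIZE
    cols = np.arange(TARGET_SIZE) * w // TARGET_SIZE
    resized = img[rows[:, None], cols[None, :]]  # shape (224, 224, 3)
    return (resized / 255.0).astype(np.float32)


def make_batch(images):
    """Stack several preprocessed images into one (N, 224, 224, 3) feed array."""
    return np.stack([load_image_array_sketch(im) for im in images], axis=0)


gray = np.full((100, 150), 255, dtype=np.uint8)
rgb = np.zeros((300, 200, 3), dtype=np.uint8)
batch = make_batch([gray, rgb])
print(batch.shape, batch.dtype)  # (2, 224, 224, 3) float32
```

Because the placeholder is declared with a None batch dimension, a stacked array like this could be fed in a single sess.run call, which avoids reloading the graph once per image.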