main.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目: grad-cam.tensorflow 作者: Ankush96 项目源码 文件源码
def _normalize_to_unit_max(arr):
    """Return ``arr`` as a float array divided by its maximum value.

    A zero maximum (e.g. an all-black image) would otherwise yield NaN/inf
    via division by zero; in that case the float array is returned unscaled.

    Args:
        arr: Array-like of numeric values.

    Returns:
        A new float ``numpy`` array.
    """
    arr = np.asarray(arr, dtype=float)
    peak = arr.max()
    if peak != 0:
        arr = arr / peak
    return arr


def main(_):
    """Run Grad-CAM with VGG16 on ``FLAGS.input`` and save the overlay.

    Loads the input image, prints the model's top-5 predicted classes,
    computes a Grad-CAM map for the top-1 class at ``FLAGS.layer_name``,
    superimposes it on the image, writes the result to ``FLAGS.output`` and
    displays it.

    Args:
        _: Unused. Presumably the argv list passed by ``tf.app.run``; confirm
            against the caller.
    """
    x, img = load_image(FLAGS.input)

    # The context manager releases the session's resources once the graph
    # work is finished. The original never closed the session.
    with tf.Session() as sess:
        print("\nLoading Vgg")
        imgs = tf.placeholder(tf.float32, [None, 224, 224, 3])
        vgg = vgg16(imgs, 'vgg16_weights.npz', sess)

        print("\nFeedforwarding")
        # Probability vector for the first (only) image in the batch.
        prob = sess.run(vgg.probs, feed_dict={vgg.imgs: x})[0]
        # Indices of the five highest probabilities, best first.
        preds = np.argsort(prob)[::-1][:5]
        print('\nTop 5 classes are')
        for p in preds:
            print(class_names[p], prob[p])

        # Target class: the top-1 prediction.
        predicted_class = preds[0]
        # Target layer for visualization.
        layer_name = FLAGS.layer_name
        # Number of output classes, read from the model's own output rather
        # than hard-coded (the original hard-coded 1000).
        nb_classes = len(prob)

        cam3 = grad_cam(x, vgg, sess, predicted_class, layer_name, nb_classes)

    # Scale the image so its maximum value is 1.
    img = _normalize_to_unit_max(img)

    # Superimpose the 3x-weighted visualization on the image, then rescale.
    new_img = _normalize_to_unit_max(img + 3 * cam3)

    # Save first so the result is written even if the interactive display
    # fails or is interrupted; plt.show() may block.
    io.imsave(FLAGS.output, new_img)
    io.imshow(new_img)
    plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号