vgg_loss.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:vae-celebA 作者: yzwxx 项目源码 文件源码
def main():
    x = tf.placeholder(tf.float32, [None, 224, 224, 3])
    network, probs = build_vgg(x)
    # network2, probs2 = build_vgg(x)
    sess = tf.InteractiveSession()
    tl.layers.initialize_global_variables(sess)
    network.print_params()
    network.print_layers()


    npz = np.load('vgg16_weights.npz')
    params = []
    for val in sorted( npz.items() ):
        print("  Loading %s" % str(val[1].shape))
        params.append(val[1])
    tl.files.assign_params(sess, params, network)

    img1 = imread('laska.png', mode='RGB') 
    img1 = imresize(img1, (224, 224))

    prob = sess.run(probs, feed_dict={x: [img1]})[0]
    print(prob)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号