def main(weights_path='vgg16_weights.npz', image_path='laska.png'):
    """Build VGG16, load pretrained weights, and print class probabilities for one image.

    Args:
        weights_path: Path to an ``.npz`` archive of pretrained parameters.
            Arrays are passed to ``tl.files.assign_params`` in order of their
            sorted archive key names. This assumes the key names sort into the
            same order as ``network``'s parameters — TODO confirm against
            ``build_vgg``.
        image_path: Path to the input image. It is read as RGB and resized to
            224x224 to match the placeholder shape.
    """
    x = tf.placeholder(tf.float32, [None, 224, 224, 3])
    network, probs = build_vgg(x)

    sess = tf.InteractiveSession()
    try:
        tl.layers.initialize_global_variables(sess)
        network.print_params()
        network.print_layers()

        # np.load on an .npz returns an NpzFile that holds the archive open;
        # the context manager guarantees it is closed once weights are read.
        with np.load(weights_path) as npz:
            params = []
            # Sort by key name only: weights are assigned positionally, so the
            # key order defines which layer each array goes to.
            for key in sorted(npz.files):
                weights = npz[key]
                print(" Loading %s" % str(weights.shape))
                params.append(weights)
        tl.files.assign_params(sess, params, network)

        img1 = imread(image_path, mode='RGB')
        img1 = imresize(img1, (224, 224))
        # Feed a batch of one image and take the first (only) row of output.
        prob = sess.run(probs, feed_dict={x: [img1]})[0]
        print(prob)
    finally:
        # Release the session's resources even if loading/inference fails.
        sess.close()
评论列表
文章目录