def Get_Pre_Trained_Weights(input_vars, name, n_timewin=7):
    """Run a frozen VGG16 graph on each time step and stack the pool5 features.

    Loads the serialized GraphDef from ``vgg16.tfmodel`` (relative to the
    current working directory). It then maps a new float32 placeholder of
    shape (None, 64, 64, 3) onto the graph's ``images`` input and evaluates
    the imported ``pool5`` tensor once per time step.

    Args:
        input_vars: array sliced as ``input_vars[:, i, :, :, :]``. Each slice
            is fed to the (None, 64, 64, 3) placeholder, so the array
            presumably has shape (batch, n_timewin, 64, 64, 3). TODO confirm
            against callers.
        name: name given to the new input placeholder.
        n_timewin: number of time steps (second axis of ``input_vars``) to
            process. Defaults to 7, the value that used to be hard-coded.

    Returns:
        A tensor built with ``tf.pack(..., axis=1)`` from the per-step
        flattened pool5 activations, i.e. shape (batch, n_timewin, features).
    """
    with open("vgg16.tfmodel", mode="rb") as f:
        file_content = f.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(file_content)

    images = tf.placeholder(tf.float32, shape=(None, 64, 64, 3), name=name)
    # Get pool5 directly from *this* import instead of looking up the fixed
    # name "import/pool5:0". On a second call TF names the new copy
    # "import_1/...", so the fixed name would still resolve to the first
    # copy. That copy is wired to a different placeholder, which is never fed.
    (pool5,) = tf.import_graph_def(
        graph_def,
        input_map={"images": images},
        return_elements=["pool5:0"],
    )
    print("graph loaded from disk")

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        convnets = []
        for i in range(n_timewin):
            pool_values = sess.run(
                pool5, feed_dict={images: input_vars[:, i, :, :, :]})
            # flatten keeps the batch dimension and collapses the rest.
            convnets.append(tf.contrib.layers.flatten(pool_values))
    return tf.pack(convnets, axis=1)
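# 评论列表 ("Comment list") -- leftover web-page text, not code
# 文章目录 ("Table of contents") -- leftover web-page text, not code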