def load_images(image_files, vgg, pl_images):
    """Extract VGG ``conv5_4`` features for each image file.

    Every readable image is scaled by 1/255, fed through ``vgg.conv5_4`` as a
    batch of one, and its activations are flattened into a row of length
    ``feat_len`` (a module-level value defined outside this function): extra
    values are dropped and missing ones are zero-filled, per
    ``ndarray.resize``.

    Args:
        image_files: Sized iterable of image paths readable by
            ``skimage.io.imread``.
        vgg: Model object exposing a ``conv5_4`` tensor.
        pl_images: Placeholder that ``vgg.conv5_4`` is computed from.

    Returns:
        ``np.float32`` array of shape ``(n_loaded, feat_len)``, where
        ``n_loaded`` is the number of images processed successfully, in
        input order. Missing, unreadable, or non-3-D images are reported and
        skipped rather than raising.
    """
    dataset = np.ndarray(shape=(len(image_files), feat_len), dtype=np.float32)
    image_index = 0
    # One session for the whole loop: the graph is identical for every image,
    # so opening a new tf.Session per image only adds overhead.
    # NOTE: tf.device() only affects ops *created* inside its scope; wrapping
    # sess.run() in it (as the previous version did) had no effect. Pin the
    # device where the VGG graph is built if CPU placement is required.
    with tf.Session() as sess:
        for image in image_files:
            if not tf.gfile.Exists(image):
                # tf.logging.fatal only logs (it does not abort), so the old
                # code still fell through to imread; skip explicitly instead.
                tf.logging.warning('File does not exist %s', image)
                continue
            try:
                image_data = skimage.io.imread(image)
            except IOError as e:
                print('Could not read:', image, ':', e, '- it\'s ok, skipping.')
                continue
            if image_data.ndim != 3:
                # A (H, W, C) image is required to build the (1, H, W, C)
                # batch; e.g. grayscale (H, W) images previously crashed the
                # whole load with an IndexError.
                print('Unexpected image shape', image_data.shape, 'for',
                      image, '- skipping.')
                continue
            # assumes 8-bit pixel values (0-255) — TODO confirm for other dtypes
            batch = (image_data / 255.0)[np.newaxis, ...].astype(np.float32)
            feat = sess.run(vgg.conv5_4, feed_dict={pl_images: batch})
            # In-place resize flattens the activations (C order) and truncates
            # or zero-pads them to exactly feat_len values.
            feat.resize(feat_len, refcheck=False)
            dataset[image_index, :] = feat
            image_index += 1
    # Drop the unused rows reserved for images that were skipped.
    dataset = dataset[:image_index, :]
    print('Full dataset tensor:', dataset.shape)
    return dataset
# (Scraped web-page residue removed from code: "评论列表" = "Comment list",
#  "文章目录" = "Table of contents".)