data_generation.py source code

python

Project: IM2TXT  Author: aayushP
import numpy as np
import skimage.io
import tensorflow as tf  # TF 1.x API (tf.gfile, tf.logging, tf.Session)


def load_images(image_files, vgg, pl_images):
    """Run each image through VGG and return its conv5_4 features, one row per image.

    feat_len (the flattened feature length) is expected to be defined at module level.
    """
    dataset = np.ndarray(shape=(len(image_files), feat_len), dtype=np.float32)
    image_index = 0

    # Open one session for all images instead of a new session per image.
    with tf.Session() as sess:
        for image in image_files:
            try:
                if not tf.gfile.Exists(image):
                    tf.logging.fatal('File does not exist %s', image)
                    continue
                # Scale to [0, 1] (assumes 8-bit pixels) and add a batch dimension.
                image_data = skimage.io.imread(image) / 255.0
                batch = image_data[np.newaxis, ...].astype(np.float32)

                with tf.device("/cpu:0"):
                    feat = sess.run(vgg.conv5_4, feed_dict={pl_images: batch})

                # Flatten in place; resize zero-pads or truncates to feat_len.
                feat.resize(feat_len, refcheck=False)
                dataset[image_index, :] = feat
                image_index += 1

            except IOError as e:
                print('Could not read:', image, ':', e, "- it's ok, skipping.")

    # Drop the unused rows left by skipped images.
    dataset = dataset[0:image_index, :]

    print('Full dataset tensor:', dataset.shape)
    return dataset
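
The function above follows a preallocate, fill, trim pattern: it reserves one row per input file, writes a row only for inputs that load, and slices off the unused rows at the end. The sketch below isolates that pattern using NumPy only. `FEAT_LEN`, `fake_features` and `collect_features` are hypothetical names introduced for illustration; `fake_features` is a stand-in for the VGG forward pass and is not part of the original project.

```python
import numpy as np

FEAT_LEN = 8  # hypothetical feature length, chosen only for this sketch


def fake_features(image):
    """Hypothetical stand-in for the VGG forward pass; raises IOError on bad input."""
    if image is None:
        raise IOError("unreadable image")
    return np.asarray(image, dtype=np.float32).ravel()


def collect_features(images, feat_len=FEAT_LEN):
    # Preallocate one row per input; rows for skipped inputs are trimmed later.
    dataset = np.zeros((len(images), feat_len), dtype=np.float32)
    count = 0
    for image in images:
        try:
            feat = fake_features(image)
        except IOError as e:
            print("Could not read input:", e, "- skipping.")
            continue
        # Zero-pad or truncate to exactly feat_len values,
        # mirroring what an in-place ndarray.resize does.
        row = np.zeros(feat_len, dtype=np.float32)
        n = min(feat.size, feat_len)
        row[:n] = feat[:n]
        dataset[count] = row
        count += 1
    # Keep only the rows that were actually filled.
    return dataset[:count]


images = [np.ones((2, 2, 2)), None, np.full((2, 2), 0.5)]
result = collect_features(images)
print(result.shape)
```

Because the row counter advances only on success, the returned array has no gaps, but its row order no longer lines up with the input list whenever an input is skipped. A caller that needs to map features back to file names would have to track the accepted names separately.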