train.py source code

python

Project: tf-tutorial  Author: zchen0211
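import matplotlib.pyplot as plt
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, queue runners)


def visualize_input(model):
  """Fetch one input batch from `model` and display its first image and caption."""
  sess = tf.Session()
  sess.run(tf.global_variables_initializer())
  # Start the input-pipeline queue runners under a coordinator so they can be stopped cleanly.
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)

  batch_img, batch_cap = sess.run([model.images, model.input_seqs])

  # Stop the queue-runner threads and release the session once the batch is fetched.
  coord.request_stop()
  coord.join(threads)
  sess.close()

  # Show the first image, mapping pixel values from [-1, 1] to [0, 1] for imshow.
  batch_img = batch_img[0, :]
  batch_img = (batch_img + 1.) / 2.

  # Load the vocabulary: each line of word_counts.txt is "<word> <count>",
  # and a word's id is its line index.
  with open('/media/DATA/MS-COCO/word_counts.txt') as fid:
    words = [line.split()[0] for line in fid]

  # Map the token ids of the first caption back to words.
  sentence = [words[int(tmp_id)] for tmp_id in batch_cap[0]]
  print(sentence)
  plt.imshow(batch_img)
  plt.show()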
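The caption-decoding step in visualize_input does not depend on TensorFlow, so it can be pulled out and checked on its own. The sketch below assumes the word_counts.txt layout the function relies on (one `word count` pair per line, with a word's id equal to its line index). The names `load_vocab` and `decode_caption`, and the `<UNK>` fallback for ids past the end of the vocabulary, are hypothetical conventions for illustration, not part of the tf-tutorial project.

```python
from typing import Iterable, List, Sequence


def load_vocab(lines: Iterable[str]) -> List[str]:
    """Build an id -> word list from "<word> <count>" lines; the id is the line index."""
    # Ignore blank lines (e.g. a trailing empty line); keep only the word column.
    return [line.split()[0] for line in lines if line.strip()]


def decode_caption(ids: Sequence[int], words: Sequence[str], unk: str = "<UNK>") -> List[str]:
    """Map token ids back to words; ids outside the vocabulary become `unk` (hypothetical)."""
    # int() also accepts NumPy integer types, as returned by sess.run.
    return [words[i] if 0 <= i < len(words) else unk for i in map(int, ids)]


if __name__ == "__main__":
    # Hypothetical three-word vocabulary in the word_counts.txt layout.
    vocab = load_vocab(["<S> 100\n", "a 50\n", "dog 7\n"])
    print(decode_caption([0, 1, 2, 5], vocab))
```

Reading the vocabulary from an iterable of lines rather than a path lets the same helper accept an open file (`with open(path) as f: load_vocab(f)`) or an in-memory list in a test.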