inference_test.py — file source code

python
阅读 28 收藏 0 点赞 0 评论 0

Project: youtube-8m | Author: wangheda | Project source | File source
def _video_id_to_str(video_id):
  """Return a single video id as text suitable for writing one per line.

  If the id is a real ``bytes`` object (Python 3, where ``bytes`` is not
  ``str``), it is decoded as UTF-8. The previous ``''.join(str(e) for e in
  item)`` would have iterated the ints of the bytes and produced e.g.
  "979899" for b"abc". In every other case (including Python 2, where
  ``bytes is str``) the original per-element join is kept unchanged.
  """
  if isinstance(video_id, bytes) and not isinstance(video_id, str):
    return video_id.decode("utf-8")
  return ''.join(str(e) for e in video_id)


def _write_video_ids(path, video_ids):
  """Write each id in ``video_ids`` to ``path``, one per line, then close the file."""
  with open(path, 'w') as out:
    for item in video_ids:
      out.write("%s\n" % _video_id_to_str(item))


def inference(reader, train_dir, data_pattern, out_file_location, batch_size, top_k):
  """Drain video ids from the input pipeline and dump the trailing chunk to disk.

  Ids are fetched batch by batch and counted in chunks of ``FLAGS.file_size``.
  Each time a chunk fills up, its ids are discarded and a chunk counter is
  incremented; writing of full chunks is disabled. When the input is
  exhausted (or the loop stops for any other reason), the remaining partial
  chunk is written to ``inference_test/video_id_test_<filenum>.out``.

  Args:
    reader: Passed through to ``get_input_data_tensors``.
    train_dir: Unused in this function.
    data_pattern: Passed through to ``get_input_data_tensors``.
    out_file_location: Only used in the completion log message.
    batch_size: Passed through to ``get_input_data_tensors``.
    top_k: Unused in this function.
  """
  with tf.Session() as sess:
    video_id_batch, video_batch, video_label_batch, num_frames_batch = get_input_data_tensors(reader, data_pattern, batch_size)

    # Workaround for num_epochs issue: epoch-counter locals whose name
    # contains "train_input" are assigned 1 instead of being initialized,
    # and are removed from the collection so the generic initializer skips them.
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):  # iterate a copy; we mutate `variables`
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(set_up_init_ops(tf.get_collection_ref(
        tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    num_examples_processed = 0
    filenum = 0
    video_id = []
    try:
      while not coord.should_stop():
        video_id_batch_val = sess.run(video_id_batch)
        video_id.extend(video_id_batch_val)
        num_examples_processed += len(video_id_batch_val)

        if num_examples_processed >= FLAGS.file_size:
          # An overshoot means batch_size does not evenly divide file_size.
          if num_examples_processed > FLAGS.file_size:
            print("Wrong!", num_examples_processed)
          else:
            print(num_examples_processed)
          # Full chunks are intentionally not written (that output was
          # disabled); only the chunk index advances and the buffer resets.
          filenum += 1
          video_id = []
          num_examples_processed = 0

    except tf.errors.OutOfRangeError:
      # NOTE(review): nothing in this function writes to out_file_location;
      # the actual output goes under inference_test/ (see below).
      logging.info('Done with inference. The output file was written to ' + out_file_location)
    finally:
      coord.request_stop()
      if num_examples_processed < FLAGS.file_size:
        print(num_examples_processed)
        _write_video_ids(
            'inference_test/video_id_test_' + str(filenum) + '.out', video_id)

    coord.join(threads)
    sess.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号