inference-combine-tfrecords-video.py source file

python
Project: youtube-8m  Author: wangheda
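import tensorflow as tf

# Note: get_input_data_tensors is assumed to be defined elsewhere in this file.


def build_graph(input_reader, input_data_pattern,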
                prediction_reader, prediction_data_pattern,
                batch_size=256):
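  """Creates the TensorFlow graph that reads inputs and predictions in parallel.

  The function returns nothing. It publishes its tensors through graph
  collections ("video_ids_equal", "labels_equal", "video_ids_batch",
  "labels_batch", "inputs_batch", "predictions_batch").

  Args:
    input_reader: The reader for the input data files. It should inherit
                  from BaseReader.
    input_data_pattern: glob path to the input data files.
    prediction_reader: The reader for the prediction data files. It should
                       inherit from BaseReader.
    prediction_data_pattern: glob path to the prediction data files.
    batch_size: How many examples to process at a time.
  """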

  video_ids_batch, model_inputs_batch, labels_batch, unused_num_frames = (
      get_input_data_tensors(
          input_reader,
          input_data_pattern,
          batch_size=batch_size))
  video_ids_batch2, model_predictions_batch, labels_batch2, unused_num_frames2 = (
      get_input_data_tensors(
          prediction_reader,
          prediction_data_pattern,
          batch_size=batch_size))

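  # Fraction of positions in the batch where both readers yield the same video id.
  video_ids_equal = tf.reduce_mean(
      tf.cast(tf.equal(video_ids_batch, video_ids_batch2), tf.float32))
  # Mean number of matching label positions per example. When both batches are
  # fully aligned this equals the number of classes, not 1.0.
  labels_equal = tf.reduce_mean(
      tf.reduce_sum(
          tf.cast(tf.equal(labels_batch, labels_batch2), tf.float32), axis=1))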

  tf.add_to_collection("video_ids_equal", video_ids_equal)
  tf.add_to_collection("labels_equal", labels_equal)
  tf.add_to_collection("video_ids_batch", video_ids_batch)
  tf.add_to_collection("labels_batch", tf.cast(labels_batch, tf.float32))
  tf.add_to_collection("inputs_batch", model_inputs_batch)
  tf.add_to_collection("predictions_batch", model_predictions_batch)
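
The two sanity-check tensors above can be hard to interpret at a glance, so here is a minimal plain-Python sketch of the same arithmetic. It is only an illustration under stated assumptions: the helper name `alignment_stats` and the toy data are hypothetical and are not part of the original project, and plain lists stand in for the TensorFlow batches.

```python
def alignment_stats(ids_a, ids_b, labels_a, labels_b):
    """Plain-Python analogue of video_ids_equal and labels_equal (hypothetical helper)."""
    # Fraction of positions whose video ids agree
    # (mirrors reduce_mean over the cast result of tf.equal).
    id_matches = [1.0 if a == b else 0.0 for a, b in zip(ids_a, ids_b)]
    ids_equal = sum(id_matches) / len(id_matches)

    # Per example, count the label positions that agree (reduce_sum over axis=1),
    # then average those counts over the batch (reduce_mean).
    per_example = [
        sum(1.0 for x, y in zip(row_a, row_b) if x == y)
        for row_a, row_b in zip(labels_a, labels_b)
    ]
    labels_equal = sum(per_example) / len(per_example)
    return ids_equal, labels_equal


# Toy batch of 4 examples with 3 classes; the third example is misaligned.
ids_a = ["v1", "v2", "v3", "v4"]
ids_b = ["v1", "v2", "vX", "v4"]
labels_a = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0]]
labels_b = [[1, 0, 0], [0, 1, 0], [0, 1, 1], [1, 1, 0]]

ids_equal, labels_equal = alignment_stats(ids_a, ids_b, labels_a, labels_b)
print(ids_equal, labels_equal)  # prints: 0.75 2.75
```

With perfectly aligned inputs the first value would be 1.0 and the second would equal the number of classes (3 in this toy data), so anything lower suggests that the input and prediction files are being read in a different order.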