def build_graph(input_reader, input_data_pattern,
                prediction_reader, prediction_data_pattern,
                batch_size=256):
    """Creates the TensorFlow graph that pairs input examples with predictions.

    Two input pipelines are built with ``get_input_data_tensors``: one over the
    original input data and one over previously written prediction records.
    Alignment diagnostics and the batched tensors are then registered in graph
    collections so a later session loop can fetch them by name. Nothing is
    returned.

    Args:
      input_reader: Reader for the original input data files (presumably a
        BaseReader subclass -- confirm against get_input_data_tensors).
      input_data_pattern: Glob path to the input data files.
      prediction_reader: Reader for the prediction data files.
      prediction_data_pattern: Glob path to the prediction data files.
      batch_size: How many examples to process at a time. The same value is
        used for both pipelines.

    Graph collections populated (one tensor each):
      "video_ids_equal": mean of the element-wise equality between the two
        video-id batches, as float32.
      "labels_equal": number of matching label entries per example (summed
        over axis 1), averaged over the batch. This is a count, not a fraction.
      "video_ids_batch": video ids from the input pipeline.
      "labels_batch": labels from the input pipeline, cast to float32.
      "inputs_batch": model inputs from the input pipeline.
      "predictions_batch": the feature tensor read from the prediction
        pipeline (presumably the stored predictions -- confirm).
    """
    # The 4th element returned by get_input_data_tensors (number of frames,
    # judging by the original variable name) is not needed here.
    video_ids_batch, model_inputs_batch, labels_batch, _ = (
        get_input_data_tensors(
            input_reader,
            input_data_pattern,
            batch_size=batch_size))
    # The prediction records are read with the same helper; the tensor in the
    # "model inputs" position is treated as the prediction batch.
    video_ids_batch2, model_predictions_batch, labels_batch2, _ = (
        get_input_data_tensors(
            prediction_reader,
            prediction_data_pattern,
            batch_size=batch_size))

    # NOTE(review): presumably these diagnostics verify that both pipelines
    # yield the same examples in the same order -- confirm with the consumer.
    video_ids_equal = tf.reduce_mean(
        tf.cast(tf.equal(video_ids_batch, video_ids_batch2), tf.float32))
    # Matches are summed along axis 1 (assumes labels are (batch, num_labels)
    # -- TODO confirm), so a perfectly aligned batch yields the label width,
    # not 1.0. Kept as-is because readers of the collection may rely on it.
    labels_equal = tf.reduce_mean(
        tf.reduce_sum(
            tf.cast(tf.equal(labels_batch, labels_batch2), tf.float32),
            axis=1))

    tf.add_to_collection("video_ids_equal", video_ids_equal)
    tf.add_to_collection("labels_equal", labels_equal)
    tf.add_to_collection("video_ids_batch", video_ids_batch)
    tf.add_to_collection("labels_batch", tf.cast(labels_batch, tf.float32))
    tf.add_to_collection("inputs_batch", model_inputs_batch)
    tf.add_to_collection("predictions_batch", model_predictions_batch)
inference-combine-tfrecords-video.py 文件源码
python
阅读 36
收藏 0
点赞 0
评论 0
评论列表
文章目录