check_video_id.py source file

python

Project: youtube-8m  Author: wangheda
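import tensorflow as tf  # written against the TensorFlow 1.x graph API


def build_graph(all_readers,
                input_reader,
                input_data_pattern,
                all_eval_data_patterns,
                batch_size=256):
  """Builds ops that check whether several readers yield the same batches.

  get_input_evaluation_tensors is not shown in this excerpt.
  """
  # Reference batch; only its video ids are used below.
  original_video_id, original_input, unused_labels_batch, unused_num_frames = (
      get_input_evaluation_tensors(
          input_reader,
          input_data_pattern,
          batch_size=batch_size))

  video_id_equal_tensors = []
  model_input_tensor = None
  input_distance_tensors = []
  for reader, data_pattern in zip(all_readers, all_eval_data_patterns):
    video_id, model_input_raw, labels_batch, unused_num_frames = (
        get_input_evaluation_tensors(
            reader,
            data_pattern,
            batch_size=batch_size))
    # Despite the name, this counts ids that differ from the reference batch
    # (0 means every video id matches).
    video_id_equal_tensors.append(
        tf.reduce_sum(
            tf.cast(tf.not_equal(original_video_id, video_id),
                    dtype=tf.float32)))
    # The first reader's input is the baseline for the distance check.
    if model_input_tensor is None:
      model_input_tensor = model_input_raw
    # Squared differences summed over axis 1, then averaged.
    input_distance_tensors.append(
        tf.reduce_mean(
            tf.reduce_sum(tf.square(model_input_tensor - model_input_raw),
                          axis=1)))

  # One entry per reader.
  video_id_equal_tensor = tf.stack(video_id_equal_tensors)
  input_distance_tensor = tf.stack(input_distance_tensors)

  # Expose the results through graph collections for later retrieval.
  tf.add_to_collection("video_id_equal", video_id_equal_tensor)
  tf.add_to_collection("input_distance", input_distance_tensor)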
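
The two quantities that build_graph collects can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not the project's own code: the helper names count_id_mismatches and mean_squared_distance and the sample data are hypothetical, and it assumes 2-D video-level features (one row per video), so that summing over axis 1 means summing over each row.

```python
def count_id_mismatches(reference_ids, candidate_ids):
    """Count positions where two batches disagree on the video id.

    Mirrors reduce_sum(cast(not_equal(a, b))): 0 means the batches align.
    """
    return sum(1 for a, b in zip(reference_ids, candidate_ids) if a != b)


def mean_squared_distance(reference_rows, candidate_rows):
    """Average, over the batch, of the per-row sum of squared differences.

    Mirrors reduce_mean(reduce_sum(square(a - b), axis=1)) for 2-D inputs.
    """
    per_row = [
        sum((x - y) ** 2 for x, y in zip(ref_row, cand_row))
        for ref_row, cand_row in zip(reference_rows, candidate_rows)
    ]
    return sum(per_row) / len(per_row)


# Hypothetical batches from two readers over the same three videos.
ids_a = ["v1", "v2", "v3"]
ids_b = ["v1", "vX", "v3"]          # second id differs
feats_a = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
feats_b = [[0.0, 1.0], [2.0, 5.0], [4.0, 5.0]]  # one value differs by 2

mismatches = count_id_mismatches(ids_a, ids_b)
distance = mean_squared_distance(feats_a, feats_b)
print(mismatches, distance)
```

A mismatch count of zero for every reader suggests the readers return videos in the same order; a non-zero distance then points at differing feature values rather than misaligned batches.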