multires_lstm_memory_deep_combine_chain_model.py source code

python

Project: youtube-8m  Author: wangheda
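import tensorflow as tf

def resolution(self, model_input_raw, num_frames, resolution):
    """Average-pool frame features over time by `resolution`, then L2-normalize.

    model_input_raw: rank-3 tensor [batch, max_frames, num_features] with a static shape.
    num_frames: integer tensor of per-video valid frame counts.
    """
    frame_dim = len(model_input_raw.get_shape()) - 2
    feature_dim = len(model_input_raw.get_shape()) - 1
    max_frames = model_input_raw.get_shape().as_list()[frame_dim]
    num_features = model_input_raw.get_shape().as_list()[feature_dim]
    if resolution > 1:
      # Floor division (`/` yields a float in Python 3 and breaks slicing):
      # trailing frames that do not fill a whole window are dropped.
      new_max_frames = max_frames // resolution
      cut_frames = new_max_frames * resolution
      model_input_raw = model_input_raw[:, :cut_frames, :]
      # Group consecutive frames into windows of `resolution` and average each window.
      model_input_raw = tf.reshape(model_input_raw, shape=[-1, new_max_frames, resolution, num_features])
      model_input_raw = tf.reduce_mean(model_input_raw, axis=2)

      model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)
      # Valid-frame counts shrink by the same factor (integer floor division).
      num_frames = num_frames // resolution
    else:
      model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)
    return model_input, num_frames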
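To make the shape bookkeeping concrete, here is a minimal NumPy sketch of the same steps, assuming a rank-3 input of shape [batch, max_frames, num_features] and integer frame counts. The helper name `temporal_mean_pool` is hypothetical and is not part of the youtube-8m project; the epsilon of 1e-12 mirrors the default of `tf.nn.l2_normalize`.

```python
import numpy as np

def temporal_mean_pool(frames: np.ndarray, num_frames: np.ndarray, resolution: int):
    """Hypothetical NumPy re-implementation of the pooling above (not from the repo).

    frames: [batch, max_frames, num_features]; num_frames: [batch] integer counts.
    """
    batch, max_frames, num_features = frames.shape
    if resolution > 1:
        new_max_frames = max_frames // resolution      # number of whole windows
        cut_frames = new_max_frames * resolution       # drop the incomplete tail
        windows = frames[:, :cut_frames, :].reshape(
            batch, new_max_frames, resolution, num_features)
        frames = windows.mean(axis=2)                  # average each window
        num_frames = num_frames // resolution          # shrink valid lengths
    # L2-normalize each frame vector along the feature axis; the epsilon
    # (as in tf.nn.l2_normalize) keeps all-zero padding frames at zero.
    sq_sum = (frames ** 2).sum(axis=-1, keepdims=True)
    return frames / np.sqrt(np.maximum(sq_sum, 1e-12)), num_frames

# Usage: 7 frames pooled in windows of 2 -> 3 pooled frames (frame 7 is dropped).
x = np.arange(2 * 7 * 3, dtype=np.float64).reshape(2, 7, 3)
pooled, n = temporal_mean_pool(x, np.array([7, 5]), resolution=2)
print(pooled.shape, n)  # (2, 3, 3) [3 2]
```

Because the counts are floored, a video with 5 valid frames keeps only 2 pooled frames; its third window, which mixes one real frame with padding, is treated as padding.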