def resolution(self, model_input_raw, num_frames, resolution):
    """Temporally downsample frame features by average-pooling, then L2-normalize.

    When ``resolution > 1``, the frame axis is truncated to a multiple of
    ``resolution`` and each run of ``resolution`` consecutive frames is
    replaced by its mean; ``num_frames`` is floor-divided by the same factor.
    In every case the result is L2-normalized along the last (feature) axis.

    Args:
        model_input_raw: Frame-level feature tensor. The pooling path slices it
            as ``[:, :cut_frames, :]``, so it must be rank 3 there — presumably
            (batch, max_frames, num_features); TODO confirm. Its frame and
            feature dimensions must be statically known for the pooling path.
        num_frames: Frame count(s) for the input; presumably an integer tensor
            of per-example valid lengths — verify against caller.
        resolution: Integer pooling factor; values <= 1 skip pooling.

    Returns:
        A ``(model_input, num_frames)`` tuple: the (possibly pooled)
        L2-normalized features and the correspondingly scaled frame count.
    """
    static_shape = model_input_raw.get_shape().as_list()
    rank = len(static_shape)
    frame_dim = rank - 2
    feature_dim = rank - 1
    max_frames = static_shape[frame_dim]
    num_features = static_shape[feature_dim]

    if resolution > 1:
        # Floor division keeps these as Python ints; plain `/` would produce
        # floats under Python 3 and break the slice and reshape below.
        new_max_frames = max_frames // resolution
        cut_frames = new_max_frames * resolution
        # Drop trailing frames that do not fill a complete pooling window.
        model_input_raw = model_input_raw[:, :cut_frames, :]
        # Split the frame axis into (windows, frames-per-window) and average
        # over each window.
        model_input_raw = tf.reshape(
            model_input_raw,
            shape=[-1, new_max_frames, resolution, num_features])
        model_input_raw = tf.reduce_mean(model_input_raw, axis=2)
        # Floor division keeps the frame count integral; `/` on an int tensor
        # is true division in Python 3 and would yield a float64 tensor.
        # NOTE(review): counts smaller than `resolution` become 0 — confirm
        # callers tolerate zero-length sequences.
        num_frames = num_frames // resolution

    # Pooling preserves rank, so `feature_dim` is still the last axis here.
    model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)
    return model_input, num_frames
multires_lstm_memory_deep_combine_chain_model.py 文件源码
python
阅读 37
收藏 0
点赞 0
评论 0
评论列表
文章目录