def avg(self, model_input_raw, num_frames, mask):
    """Compute a mask-weighted per-example mean over frames and tile it back.

    Args:
      model_input_raw: 3-D tensor; the einsum below indexes it as
        (batch, frames, features).
      num_frames: per-example frame counts; expanded to (batch, 1) below, so
        presumably shape (batch,) -- confirm against callers.
      mask: 2-D weights indexed as (batch, frames). Presumably 1 for valid
        frames and 0 for padding -- confirm against callers.

    Returns:
      A (batch, frames, features) tensor in which every frame position holds
      sum_t(mask[b, t] * x[b, t, :]) / max(num_frames[b], 1), i.e. the masked
      mean when `mask` marks exactly `num_frames` valid frames.
    """
    dtype = model_input_raw.dtype
    # Prefer the static frame count; fall back to the dynamic one when the
    # frame axis is unknown at graph-construction time (tf.tile rejects None).
    max_frames = model_input_raw.get_shape().as_list()[1]
    if max_frames is None:
        max_frames = tf.shape(model_input_raw)[1]
    # Clamp the divisor to at least 1 to avoid division by zero for examples
    # that report zero frames.
    num_frames_matrix = tf.maximum(
        tf.cast(tf.expand_dims(num_frames, axis=1), dtype=dtype), 1.0)
    # Per-frame averaging weights; cast so both einsum operands share a dtype.
    mean_matrix = tf.cast(mask, dtype) / num_frames_matrix
    # Weighted sum over the frame axis -> (batch, features).
    mean_input = tf.einsum("ijk,ij->ik", model_input_raw, mean_matrix)
    # Repeat the mean along a new frame axis to match the input's layout.
    return tf.tile(tf.expand_dims(mean_input, axis=1),
                   multiples=[1, max_frames, 1])
评论列表
文章目录