def FramePooling(frames, method, **unused_params):
    """Pools over the frames of a video.

    Args:
      frames: A tensor with shape [batch_size, num_frames, feature_size].
      method: "average", "max", or "none".
      **unused_params: Ignored; accepted so callers can pass a shared
        keyword-argument dict.

    Returns:
      A tensor with shape [batch_size, feature_size] for "average" or "max"
      pooling. A tensor with shape [batch_size*num_frames, feature_size] for
      "none" pooling.

    Raises:
      ValueError: if method is not one of "average", "max", or "none".
    """
    if method == "average":
        return tf.reduce_mean(frames, 1)
    elif method == "max":
        return tf.reduce_max(frames, 1)
    elif method == "none":
        # Tensors expose their static shape via get_shape().as_list();
        # there is no `shape_as_list` method.
        feature_size = frames.get_shape().as_list()[2]
        if feature_size is None:
            # Static feature dimension unknown at graph-build time: fall back
            # to the dynamic shape so the reshape still works.
            feature_size = tf.shape(frames)[2]
        # Flatten batch and frame axes together, keeping the feature axis.
        return tf.reshape(frames, [-1, feature_size])
    else:
        raise ValueError("Unrecognized pooling method: %s" % method)
评论列表
文章目录