def apply_time_pooling(inputs, sequence_length, stride, pooling_avg=False):
    """Downsample a batch of sequences along the time axis by ``stride``.

    Args:
        inputs: rank-3 tensor indexed as ``[batch, time, depth]`` (the
            slicing below fixes this layout). The static ``depth`` is
            restored on the output when it is known.
        sequence_length: lengths of the input sequences. They are returned
            divided by ``stride`` and rounded up.
        stride: pooling factor along the time axis (positive int).
        pooling_avg: if True, each output frame is the mean of the (up to)
            ``stride`` consecutive input frames it covers. Otherwise the first
            frame of each group is kept (plain subsampling).

    Returns:
        ``(pooled_inputs, pooled_sequence_length)``. The time dimension of
        ``pooled_inputs`` is ``ceil(time / stride)``.
    """
    # Static depth is used so the output keeps a known last dimension.
    # Fall back to the dynamic depth when it is not statically known.
    static_depth = inputs.get_shape().with_rank(3).as_list()[2]
    depth = static_depth if static_depth is not None else tf.shape(inputs)[2]
    batch_size = tf.shape(inputs)[0]

    if pooling_avg:
        # Slice k holds frames k, k + stride, k + 2 * stride, ...
        # Output frame j averages element j of every slice that has one.
        slices = [inputs[:, k::stride, :] for k in range(stride)]
        out_len = tf.shape(slices[0])[1]  # slice 0 is the longest: ceil(time / stride)
        total = tf.zeros_like(slices[0])
        counts = tf.zeros([out_len], dtype=inputs.dtype)
        for frames in slices:
            len_ = tf.shape(frames)[1]
            paddings = tf.stack([[0, 0], [0, out_len - len_], [0, 0]])
            total += tf.pad(frames, paddings=paddings)
            # Count only real (unpadded) frames, so a trailing partial group is
            # averaged over its actual size instead of being diluted by zeros.
            counts += tf.sequence_mask(len_, out_len, dtype=inputs.dtype)
        # NOTE(review): padding *inside* a sequence (beyond its own
        # sequence_length but before `time`) is still averaged in as-is.
        inputs = total / counts[None, :, None]
    else:
        inputs = inputs[:, ::stride, :]

    # Re-assert [batch, pooled_time, depth]; slicing/padding with dynamic
    # sizes can drop the static depth.
    inputs = tf.reshape(inputs, tf.stack([batch_size, tf.shape(inputs)[1], depth]))
    sequence_length = (sequence_length + stride - 1) // stride  # ceil division
    return inputs, sequence_length
评论列表
文章目录