def get_output_for(self, input, **kwargs):
    """Pool every time step of a 4D input with the parent class's pooling.

    The parent ``get_output_for`` is called once, on a 3D tensor in which
    the batch and time axes have been merged. Its result is then split back
    into separate batch and time axes.

    Parameters
    ----------
    input : symbolic tensor
        Laid out as ``[batch, n-step, num_input_channels, input_length]``.
        The code indexes four shape entries, so it must be rank 4.
    **kwargs
        Passed through unchanged to the parent ``get_output_for``.

    Returns
    -------
    symbolic tensor
        Laid out as ``[batch, n-step, num_input_channels, pool_length]``.
    """
    in_shape = input.shape
    n_batch, n_steps = in_shape[0], in_shape[1]

    # Merge batch and time axes: [batch * n-step, num_input_channels, input_length].
    merged = T.reshape(input, (n_batch * n_steps, in_shape[2], in_shape[3]))
    pooled = super(PoolTimeStep1DLayer, self).get_output_for(merged, **kwargs)

    # The parent result is assumed to be 3D:
    # [batch * n-step, num_input_channels, pool_length].
    # Split the merged leading axis back into (batch, n-step).
    pooled_shape = pooled.shape
    return T.reshape(pooled, (n_batch, n_steps, pooled_shape[1], pooled_shape[2]))
评论列表
文章目录