def forward(self, *inputs):
    """Adaptively pool every input tensor to ``self.target_size`` and
    concatenate the results along the channel axis (dim 1).

    All inputs must share the same dimensionality: either 4D (batched 2D,
    i.e. NCHW) or 5D (batched 3D, i.e. NCDHW) — inferred from ``inputs[0]``.
    The pooling function is selected as
    ``F.adaptive_{self.pool_mode}_pool{2,3}d`` (presumably from
    ``torch.nn.functional``), so ``self.pool_mode`` is expected to be a
    mode name such as ``'avg'`` or ``'max'`` — TODO confirm against the
    enclosing class.

    Parameters
    ----------
    *inputs : torch.Tensor
        One or more tensors, all of the same dimensionality as ``inputs[0]``.

    Returns
    -------
    torch.Tensor
        The resized inputs concatenated along dim 1.

    Raises
    ------
    ShapeError
        If ``inputs[0]`` is not 4D or 5D, or if any other input's
        dimensionality disagrees with ``inputs[0]``.
    """
    dim = inputs[0].dim()
    assert_(dim in [4, 5],
            'Input tensors must either be 4 or 5 '
            'dimensional, but inputs[0] is {}D.'.format(dim),
            ShapeError)
    # Get resize function: 4D input -> 2 spatial dims, 5D input -> 3.
    spatial_dim = {4: 2, 5: 3}[dim]
    resize_function = getattr(F, 'adaptive_{}_pool{}d'.format(self.pool_mode,
                                                              spatial_dim))
    # Normalize the target size to one entry per spatial dimension.
    target_size = pyu.as_tuple_of_len(self.target_size, spatial_dim)
    # Do the resizing. NOTE: loop variable renamed from `input` to avoid
    # shadowing the builtin.
    resized_inputs = []
    for input_num, tensor in enumerate(inputs):
        # Make sure the dim checks out
        assert_(tensor.dim() == dim,
                "Expected inputs[{}] to be a {}D tensor, got a {}D "
                "tensor instead.".format(input_num, dim, tensor.dim()),
                ShapeError)
        resized_inputs.append(resize_function(tensor, target_size))
    # Concatenate along the channel axis
    concatenated = torch.cat(tuple(resized_inputs), 1)
    # Done
    return concatenated