def __call__(self, *inputs):
    """Adjust the contrast of one or more image tensors.

    Each input is pulled toward (``self.value < 1``) or pushed away from
    (``self.value > 1``) its own per-channel mean, then clamped to [0, 1]:

        out = clamp((x - mean(x)) * self.value + mean(x), 0, 1)

    where ``mean(x)`` is taken over the last two (spatial) dimensions, so a
    (C, H, W) tensor gets one mean per channel.

    Args:
        *inputs: one or more tensors. Presumably channel-first images of
            shape (C, H, W) with values in [0, 1] -- TODO confirm with callers.

    Returns:
        The transformed tensor when exactly one input is given, otherwise a
        list of transformed tensors in the same order as ``inputs``.

    Raises:
        ValueError: if called with no inputs.
    """
    if not inputs:
        # Previously this fell through to an UnboundLocalError on ``idx``.
        raise ValueError("expected at least one input tensor")

    outputs = []
    for _input in inputs:
        # keepdim=True leaves a (..., 1, 1) shape that broadcasts back over
        # H and W. Without it, modern PyTorch drops the reduced dim and the
        # second mean() targets a dimension that no longer exists. A mean of
        # row means equals the overall mean because every row has W elements.
        channel_means = _input.mean(-1, keepdim=True).mean(-2, keepdim=True)
        outputs.append(
            th.clamp((_input - channel_means) * self.value + channel_means, 0, 1)
        )

    # Return a list whenever more than one input was given; the old
    # ``idx > 1`` test silently discarded the second of exactly two inputs.
    return outputs if len(outputs) > 1 else outputs[0]
# Comment list
# Table of contents