def apply(self, data, copy=False):
    """Contrast-normalize every sample (row) of ``data``.

    Each sample is treated as a flat vector: its own mean is subtracted,
    then it is divided by ``norm / self.scale``, where ``norm`` is the
    Euclidean norm of the centered vector. Divisors smaller than
    ``self.epsilon`` are clamped to ``self.epsilon`` so that constant
    (zero-norm) samples do not trigger a division by zero.

    Args:
        data: NumPy array with at least two dimensions; axis 0 indexes
            the samples. Input with more than two dimensions is
            flattened to 2-D for the computation and reshaped back
            before returning. The input is not range-checked.
        copy: When true, operate on a copy and leave the caller's array
            untouched. When false, the arithmetic is performed in place
            on ``data`` (or on whatever ``reshape`` returns for it).

    Returns:
        The normalized array, with the same shape as the input.

    Raises:
        AssertionError: If ``data`` has fewer than two dimensions.
    """
    if copy:
        data = np.copy(data)

    data_shape = data.shape
    if len(data_shape) > 2:
        # np.prod, not np.product: the latter alias was removed in
        # NumPy 2.0. An explicit product (rather than -1) also keeps
        # empty batches reshaping cleanly.
        data = data.reshape(data_shape[0], int(np.prod(data_shape[1:])))
    assert len(data.shape) == 2, 'Contrast norm on flattened data'

    # Center each sample on its own mean.
    data -= data.mean(axis=1, keepdims=True)

    # Per-sample divisor, floored at epsilon to avoid dividing by ~0.
    norms = np.sqrt(np.sum(data ** 2, axis=1, keepdims=True)) / self.scale
    norms[norms < self.epsilon] = self.epsilon
    data /= norms

    # Hand back the caller's original layout.
    if data_shape != data.shape:
        data = data.reshape(data_shape)
    return data
评论列表
文章目录