def crosschannelnormalization(alpha=1e-4, k=2, beta=0.75, n=5, **kwargs):
    """Build a Lambda layer applying cross-channel normalization.

    This is the normalization used in the original AlexNet.  Every element
    of the input is divided by ``(k + alpha * S) ** beta``, where ``S`` is
    the sum of ``n`` consecutive slices, taken along axis 1, of the padded
    squared input.

    Args:
        alpha: Multiplier applied to each squared slice before it is added
            to the scale.
        k: Starting value of the scale accumulator.
        beta: Exponent applied to the accumulated scale.
        n: Number of slices summed along axis 1; ``n // 2`` is the padding
            amount handed to ``K.spatial_2d_padding``.
        **kwargs: Forwarded unchanged to ``Lambda``.

    Returns:
        A ``Lambda`` layer wrapping the normalization, declared with an
        output shape equal to its input shape.
    """

    def f(X):
        # The unpack requires a rank-4 shape; only the size of axis 1 is used.
        # NOTE(review): presumably X is (batch, channels, rows, cols) -- confirm.
        _, channels, _, _ = X.shape
        pad = n // 2

        # Move axis 1 to the end, pad, then restore the original axis order.
        # NOTE(review): relies on spatial_2d_padding applying the second
        # padding amount to the last axis under the backend's configured
        # dimension ordering -- confirm against the Keras version in use.
        squared = K.square(X)
        moved = K.permute_dimensions(squared, (0, 2, 3, 1))
        padded = K.spatial_2d_padding(moved, (0, pad))
        padded = K.permute_dimensions(padded, (0, 3, 1, 2))

        # Accumulate k + alpha * (sum of n shifted windows along axis 1),
        # each window being `channels` wide.
        denom = k
        for offset in range(n):
            denom += alpha * padded[:, offset:offset + channels, :, :]
        return X / denom ** beta

    layer = Lambda(
        f,
        output_shape=lambda input_shape: input_shape,
        **kwargs
    )
    return layer
评论列表
文章目录