def __init__(self, incoming, num_styles=None, epsilon=1e-4,
             beta=Constant(0), gamma=Constant(1), **kwargs):
    """Set up an instance-normalization layer and its shift/scale parameters.

    Parameters
    ----------
    incoming :
        Input to the layer. Forwarded unchanged, together with ``**kwargs``,
        to the parent-class constructor.
    num_styles : int or None
        If ``None``, ``beta`` and ``gamma`` get shape ``(input_shape[1],)``,
        i.e. one value per entry of input axis 1.
        Otherwise they get shape ``(num_styles, input_shape[1])``, i.e. one
        row of parameters per style. How a row is selected is presumably
        handled in the forward pass, which is not visible here; confirm.
    epsilon : float
        Stored as ``self.epsilon``. Presumably added to the variance for
        numerical stability in the forward pass (not visible here).
    beta, gamma :
        Initializer specs handed to ``self.add_param``. Pass ``None`` to
        omit the corresponding parameter; the attribute is then ``None``.
    **kwargs :
        Forwarded to the parent-class constructor.
    """
    super(InstanceNormLayer, self).__init__(incoming, **kwargs)
    # Axes 2 and 3 of the input. Presumably the spatial axes of a
    # (batch, channel, height, width) tensor that get_output_for reduces
    # over -- confirm against the forward pass.
    self.axes = (2, 3)
    self.epsilon = epsilon

    # Identity test against the None singleton (PEP 8), not `== None`.
    if num_styles is None:
        shape = (self.input_shape[1],)
    else:
        # One parameter row per style.
        shape = (num_styles, self.input_shape[1])

    if beta is None:
        self.beta = None
    else:
        # The shift is tagged as excluded from regularization.
        self.beta = self.add_param(beta, shape, 'beta',
                                   trainable=True, regularizable=False)
    if gamma is None:
        self.gamma = None
    else:
        # The scale is tagged regularizable (mirrors Lasagne's BatchNormLayer).
        self.gamma = self.add_param(gamma, shape, 'gamma',
                                    trainable=True, regularizable=True)
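# 评论列表 ("Comment list") -- navigation residue from the web page this code was copied from
# 文章目录 ("Table of contents") -- navigation residue from the web page this code was copied from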