def batch_norm(layer, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.), **kwargs):
    """
    Wrap ``layer`` in a ``BatchNormLayer``, moving its nonlinearity there.

    Adapted from https://gist.github.com/f0k/f1a6bd3c8585c400c190

    ``layer`` is modified in place:

    * if it has a non-``None`` ``nonlinearity`` attribute, that attribute is
      replaced by ``lasagne.nonlinearities.identity`` and the original
      nonlinearity is handed to the ``BatchNormLayer`` instead;
    * if it has a non-``None`` ``b`` attribute, that entry is removed from
      ``layer.params`` and ``layer.b`` is set to ``None``.

    Parameters
    ----------
    layer
        The layer to wrap.
    b, g
        Passed positionally to ``BatchNormLayer`` right after ``layer``.
    **kwargs
        Forwarded unchanged to ``BatchNormLayer``.

    Returns
    -------
    The ``BatchNormLayer`` constructed on top of ``layer``.
    """
    nonlinearity = getattr(layer, 'nonlinearity', None)
    if nonlinearity is not None:
        # The wrapped layer becomes linear; its nonlinearity is passed on below.
        layer.nonlinearity = lasagne.nonlinearities.identity
    else:
        nonlinearity = lasagne.nonlinearities.identity
    # A layer can expose ``b`` as an attribute while having it set to None
    # (no bias); in that case there is nothing registered in ``params`` to
    # delete, and ``del layer.params[None]`` would raise KeyError.
    if hasattr(layer, 'b') and layer.b is not None:
        del layer.params[layer.b]
        layer.b = None
    return BatchNormLayer(layer, b, g, nonlinearity=nonlinearity, **kwargs)
评论列表
文章目录