def _setup_batchnorm(self, layer):
    """Create and register a BatchNormalization link for a Caffe layer.

    Reads the batch-norm parameters and the stored statistics blobs from
    ``layer``, builds a ``links.BatchNormalization`` without gamma/beta,
    registers it as a child link under ``layer.name`` and records the
    forward callable in ``self.forwards``.

    Args:
        layer: Caffe layer message providing ``name``, ``blobs`` and
            ``batch_norm_param``.  ``blobs[0]`` and ``blobs[1]`` are copied
            into the link's ``avg_mean`` and ``avg_var``; an optional
            ``blobs[2]`` is treated as Caffe's moving-average scale factor.
    """
    # Get layer parameters.
    blobs = layer.blobs
    param = layer.batch_norm_param
    use_global_stats = param.use_global_stats
    decay = param.moving_average_fraction
    eps = param.eps
    size = int(blobs[0].shape.dim[0])  # Get channel dim from mean blob.

    # Make BatchNormalization link.
    func = links.BatchNormalization(size, decay=decay, eps=eps,
                                    use_gamma=False, use_beta=False)
    func.avg_mean.ravel()[:] = blobs[0].data
    func.avg_var.ravel()[:] = blobs[1].data

    # Caffe keeps the running mean/variance as unnormalized accumulators
    # and stores the normalizer in a third blob.  Mimic Caffe's own rule:
    # multiply by 1/factor, or by 0 when the factor itself is 0.
    # See https://github.com/BVLC/caffe/issues/4885
    if len(blobs) >= 3 and len(blobs[2].data) > 0:
        factor = blobs[2].data[0]
        scale = 0.0 if factor == 0 else 1.0 / factor
        # Update in place so the arrays registered on the link are reused.
        func.avg_mean[...] *= scale
        func.avg_var[...] *= scale

    self.add_link(layer.name, func)

    # Add layer.
    fwd = _SingleArgumentFunction(
        _CallChildLink(self, layer.name),
        test=use_global_stats, finetune=False)
    self.forwards[layer.name] = fwd
    self._add_layer(layer)
评论列表
文章目录