def bn(data, name, eps=1.001e-5, fix_gamma=False, use_global_stats=None):
    """Build a BatchNorm symbol with explicitly-created gamma/beta variables.

    Parameters
    ----------
    data : mx.sym.Symbol
        Input symbol to normalize.
    name : str
        Base name; variables are created as ``{name}_gamma`` / ``{name}_beta``.
    eps : float
        Epsilon added to the variance for numerical stability.
    fix_gamma : bool
        When True, gamma and beta are created with zero lr_mult/wd_mult so
        they are never updated, and ``fix_gamma=True`` is passed to BatchNorm.
    use_global_stats : bool or None
        When None, falls back to ``cfg['bn_use_global_stats']``
        (default False).

    Returns
    -------
    mx.sym.Symbol
        The BatchNorm output symbol.
    """
    if use_global_stats is None:
        use_global_stats = cfg.get('bn_use_global_stats', False)

    gamma_name = '{}_gamma'.format(name)
    beta_name = '{}_beta'.format(name)

    if fix_gamma:
        # Freeze both scale and shift: zero learning rate and weight decay.
        with mx.AttrScope(lr_mult='0.', wd_mult='0.'):
            gamma = mx.sym.Variable(gamma_name)
            beta = mx.sym.Variable(beta_name)
    else:
        # Learnable parameters: apply the configured lr/wd policy, treating
        # gamma as a "weight" and beta as a "bias".
        lr_type = cfg.get('lr_type', 'torch')
        with _attr_scope_lr(lr_type, 'weight'):
            gamma = mx.sym.Variable(gamma_name)
        with _attr_scope_lr(lr_type, 'bias'):
            beta = mx.sym.Variable(beta_name)

    # fix_gamma is forwarded unchanged: it matches the branch taken above.
    return mx.sym.BatchNorm(data=data, gamma=gamma, beta=beta, name=name,
                            eps=eps, fix_gamma=fix_gamma,
                            use_global_stats=use_global_stats)