symbol.py source code

python

Project: ademxapp  Author: itijyou
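import mxnet as mx

# `cfg` (a dict-like config) and `_attr_scope_lr` (presumably a helper that
# returns an attribute scope for a given lr_type and parameter kind) are
# defined elsewhere in the ademxapp code and are not shown in this excerpt.


# eps is slightly above 1e-5, likely to stay clear of cuDNN's minimum epsilon.
def bn(data, name, eps=1.001e-5, fix_gamma=False, use_global_stats=None):
    """BatchNorm layer whose gamma/beta training behavior depends on fix_gamma."""
    # An explicit argument wins; otherwise fall back to the global config.
    if use_global_stats is None:
        use_global_stats = cfg.get('bn_use_global_stats', False)

    if fix_gamma:
        # Freeze both parameters: zero learning-rate and weight-decay multipliers.
        with mx.AttrScope(lr_mult='0.', wd_mult='0.'):
            gamma = mx.sym.Variable('{}_gamma'.format(name))
            beta = mx.sym.Variable('{}_beta'.format(name))
        return mx.sym.BatchNorm(data=data, gamma=gamma, beta=beta, name=name,
                                eps=eps,
                                fix_gamma=True,
                                use_global_stats=use_global_stats)
    else:
        # Trainable gamma/beta; their multipliers follow the configured lr_type
        # ('torch' by default), with separate scopes for weights and biases.
        lr_type = cfg.get('lr_type', 'torch')
        with _attr_scope_lr(lr_type, 'weight'):
            gamma = mx.sym.Variable('{}_gamma'.format(name))
        with _attr_scope_lr(lr_type, 'bias'):
            beta = mx.sym.Variable('{}_beta'.format(name))
        return mx.sym.BatchNorm(data=data, gamma=gamma, beta=beta, name=name,
                                eps=eps,
                                fix_gamma=False,
                                use_global_stats=use_global_stats)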
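Setting `lr_mult='0.'` and `wd_mult='0.'` means the optimizer scales that parameter's learning rate and weight decay by zero, so it stays at its initialized (for example, pretrained) value. A minimal sketch of one plain SGD step with per-parameter multipliers, assuming the simplified update `w -= lr * lr_mult * (grad + wd * wd_mult * w)` with no momentum, gradient rescaling or clipping; `sgd_step` and the parameter names are hypothetical:

```python
from typing import Dict


def sgd_step(
    params: Dict[str, float],
    grads: Dict[str, float],
    base_lr: float,
    base_wd: float,
    lr_mult: Dict[str, float],
    wd_mult: Dict[str, float],
) -> Dict[str, float]:
    """One SGD step with per-parameter lr/wd multipliers (hypothetical helper)."""
    updated = {}
    for name, w in params.items():
        # Missing multipliers default to 1.0, i.e. the parameter trains normally.
        lr = base_lr * lr_mult.get(name, 1.0)
        wd = base_wd * wd_mult.get(name, 1.0)
        updated[name] = w - lr * (grads[name] + wd * w)
    return updated


params = {"bn1_gamma": 1.0, "bn1_beta": 0.2, "conv1_weight": 0.5}
grads = {"bn1_gamma": 0.3, "bn1_beta": -0.1, "conv1_weight": 0.4}
# Mirrors the fix_gamma branch: both BN parameters get lr_mult = wd_mult = 0.
frozen = {"bn1_gamma": 0.0, "bn1_beta": 0.0}
new = sgd_step(params, grads, base_lr=0.1, base_wd=1e-4,
               lr_mult=frozen, wd_mult=frozen)
print(new)
```

Because the multipliers are multiplicative, a zero freezes a parameter without removing it from the graph. What `_attr_scope_lr` sets in the other branch is not shown in this excerpt.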
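The per-channel arithmetic behind `mx.sym.BatchNorm` can be sketched in pure Python. This is a simplified model, assuming MXNet's documented semantics: `fix_gamma=True` treats gamma as 1, and `use_global_stats=True` normalizes with the stored moving mean and variance instead of the current batch's statistics. `batch_norm_channel` and its arguments are hypothetical; the real operator works on whole tensors and also updates the moving statistics during training.

```python
import math
from typing import List, Optional


def batch_norm_channel(
    xs: List[float],
    gamma: float,
    beta: float,
    eps: float = 1.001e-5,
    fix_gamma: bool = False,
    use_global_stats: bool = False,
    moving_mean: Optional[float] = None,
    moving_var: Optional[float] = None,
) -> List[float]:
    """Normalize one channel's activations (hypothetical helper, not MXNet's code)."""
    if fix_gamma:
        # With fix_gamma, the scale is treated as the constant 1.
        gamma = 1.0
    if use_global_stats:
        # Inference-style: use running statistics collected during training.
        if moving_mean is None or moving_var is None:
            raise ValueError("use_global_stats=True requires moving_mean and moving_var")
        mean, var = moving_mean, moving_var
    else:
        # Training-style: statistics of the current batch (biased variance).
        mean = sum(xs) / len(xs)
        var = sum((x - mean) ** 2 for x in xs) / len(xs)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [gamma * (x - mean) * inv_std + beta for x in xs]


# gamma=5.0 is ignored because fix_gamma=True; the output mean equals beta.
out = batch_norm_channel([1.0, 2.0, 3.0, 4.0], gamma=5.0, beta=0.5, fix_gamma=True)
print([round(v, 3) for v in out])
```

With `use_global_stats=True` the layer reduces to a fixed scale and shift, which is why it is commonly enabled when fine-tuning with small batches.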