def _attr_scope_lr(lr_type, lr_owner):
    """Return an ``mx.AttrScope`` carrying lr/wd multipliers for a parameter.

    Symbols created inside the returned scope get the ``lr_mult`` /
    ``wd_mult`` attributes below (presumably consumed by the optimizer to
    scale learning rate and weight decay per parameter).

    Multipliers as (lr_mult, wd_mult), weight; bias:

    * ``'alex'``   -- (1, 1); (2, 0)
    * ``'alex10'`` -- (10, 1); (20, 0)
    * ``'fixed'``  -- (0, 0); (0, 0), i.e. the parameter is frozen
    * ``'torch'``  -- (1, 1); (1, 1), i.e. the default scope

    An empty ``mx.AttrScope()`` stands for the default (1, 1) multipliers.

    Args:
        lr_type: One of ``'alex'``, ``'alex10'``, ``'fixed'``, ``'torch'``.
        lr_owner: ``'weight'`` or ``'bias'``. Not checked for ``'torch'``,
            which yields the default scope for any owner.

    Returns:
        An ``mx.AttrScope`` configured for ``(lr_type, lr_owner)``.

    Raises:
        ValueError: If ``lr_type`` is unknown, or if ``lr_owner`` is not
            ``'weight'``/``'bias'`` for a type other than ``'torch'``.
    """
    # Explicit raises instead of ``assert`` so validation survives ``-O``.
    # 'fixed' must be listed here, otherwise its branch is unreachable.
    known_types = ('alex', 'alex10', 'fixed', 'torch')
    if lr_type not in known_types:
        raise ValueError('unknown lr_type %r; expected one of %r'
                         % (lr_type, known_types))

    # 1, 1; 1, 1 -- default multipliers, so attach nothing.
    if lr_type == 'torch':
        return mx.AttrScope()

    if lr_owner not in ('weight', 'bias'):
        raise ValueError("lr_owner must be 'weight' or 'bias', got %r"
                         % (lr_owner,))

    # 0, 0; 0, 0 -- freeze both weight and bias.
    if lr_type == 'fixed':
        return mx.AttrScope(lr_mult='0.', wd_mult='0.')

    # 10, 1; 20, 0
    if lr_type == 'alex10':
        if lr_owner == 'weight':
            return mx.AttrScope(lr_mult='10.', wd_mult='1.')
        return mx.AttrScope(lr_mult='20.', wd_mult='0.')

    # lr_type == 'alex': 1, 1; 2, 0 (weight uses the default scope).
    if lr_owner == 'weight':
        return mx.AttrScope()
    return mx.AttrScope(lr_mult='2.', wd_mult='0.')
评论列表
文章目录