def duplicate_model_with_quant(model, bits, overflow_rate=0.0, counter=10, type='linear'):
    """Insert a quantization layer after every quantizable layer of ``model``.

    Quantizable layers are ``nn.Conv2d``, ``nn.Linear``, ``nn.BatchNorm1d``,
    ``nn.BatchNorm2d`` and ``nn.AvgPool2d``.  A quantization layer is only
    inserted when such a layer is a direct child of an ``nn.Sequential``, so
    the model is assumed to contain at least one ``nn.Sequential``.

    This is not a deep copy:

    * every ``nn.Sequential`` is rebuilt as a new container, but the layers it
      holds are the same objects as in ``model``;
    * any other container module is updated in place and returned.

    Args:
        model: root module to instrument.
        bits: bit width passed to every quantization layer.
        overflow_rate: forwarded to ``LinearQuant``; only used when
            ``type == 'linear'``.
        counter: forwarded to ``LinearQuant``; only used when
            ``type == 'linear'``.
        type: quantization scheme, one of ``'linear'``, ``'minmax'``,
            ``'log'`` or ``'tanh'``.  The name shadows the builtin but is kept
            for keyword-argument compatibility.

    Returns:
        The instrumented module.  This is a new ``nn.Sequential`` when
        ``model`` is one, otherwise ``model`` itself.

    Raises:
        ValueError: if ``type`` is not a supported scheme.
    """
    quant_type = type  # local alias so the builtin-shadowing name is not used below

    # Validate explicitly and only once.  An ``assert`` is stripped under
    # ``python -O``, which would silently map unknown schemes to tanh.
    if quant_type not in ('linear', 'minmax', 'log', 'tanh'):
        raise ValueError(
            "type must be one of 'linear', 'minmax', 'log', 'tanh'; got {!r}".format(quant_type))

    quantizable = (nn.Conv2d, nn.Linear, nn.BatchNorm1d, nn.BatchNorm2d, nn.AvgPool2d)

    def make_quant_layer(layer_name):
        """Build the quantization layer that follows the child ``layer_name``."""
        name = '{}_quant'.format(layer_name)
        if quant_type == 'linear':
            return LinearQuant(name, bits=bits, overflow_rate=overflow_rate, counter=counter)
        if quant_type == 'log':
            return NormalQuant(name, bits=bits, quant_func=log_minmax_quantize)
        if quant_type == 'minmax':
            return NormalQuant(name, bits=bits, quant_func=min_max_quantize)
        return NormalQuant(name, bits=bits, quant_func=tanh_quantize)

    def instrument(module):
        """Recursively insert quantization layers below ``module``."""
        if module is None:
            # ``add_module(name, None)`` is legal, so ``_modules`` may hold
            # ``None`` placeholders; there is nothing to instrument there.
            return None
        if isinstance(module, nn.Sequential):
            # Rebuild the container so a quant layer can follow each
            # quantizable child; the children themselves are reused.
            layers = OrderedDict()
            for name, child in module._modules.items():
                if isinstance(child, quantizable):
                    layers[name] = child
                    layers['{}_{}_quant'.format(name, quant_type)] = make_quant_layer(name)
                else:
                    layers[name] = instrument(child)
            return nn.Sequential(layers)
        # Any other container: replace its children in place.  Snapshot the
        # items so the dict is not being iterated while it is written to.
        for name, child in list(module._modules.items()):
            module._modules[name] = instrument(child)
        return module

    return instrument(model)
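# NOTE: the two lines below appear to be residue from the web page this code was copied from.
# 评论列表 ("Comment list")
# 文章目录 ("Table of contents")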