def define_teacher(params_file):
    """Define a functional teacher ResNet from a file of pretrained parameters.

    The network depth is inferred from the parameter names. For each of the
    four groups ``group0`` .. ``group3``, the number of residual blocks is the
    number of keys of the form ``group<j>.block<N>.conv0.weight``. Each block
    is a basic block in post-activation order (conv -> ReLU -> conv -> add
    shortcut -> ReLU), e.g. ResNet-18 or ResNet-34.

    Every convolution reads a ``.bias`` tensor and no batch-norm layers
    appear. Batch-norm is therefore presumably already folded into the conv
    weights/biases stored in ``params_file`` -- TODO confirm against the
    export script.

    Args:
        params_file: Path of a hickle file whose loaded value maps parameter
            names to numpy arrays. Every array is copied to the GPU.

    Returns:
        ``(f, params)``:

        * ``params`` is an ``OrderedDict`` of CUDA ``Variable`` objects keyed
          by parameter name.
        * ``f(inputs, params, pr='')`` runs the network and returns
          ``(logits, [g0, g1, g2, g3])``: the final linear-layer output and
          the outputs of the four residual groups. All returned values are
          re-wrapped from ``.data`` and so carry no autograd history.
    """
    params_hkl = hkl.load(params_file)
    params = OrderedDict((k, Variable(torch.from_numpy(v).cuda()))
                         for k, v in params_hkl.items())

    # Count residual blocks per group. The raw string makes `\d` a regex class
    # rather than an invalid string escape, and the escaped dots are literal.
    # NOTE(review): counts use unprefixed keys; a non-empty `pr` passed to `f`
    # assumes the prefixed groups have the same depth -- confirm with callers.
    blocks = [
        sum(1 for k in params
            if re.fullmatch(r'group%d\.block\d+\.conv0\.weight' % j, k))
        for j in range(4)
    ]

    def conv2d(x, params, base, stride=1, pad=0):
        # Convolution using `<base>.weight` and `<base>.bias`.
        return F.conv2d(x, params[base + '.weight'], params[base + '.bias'],
                        stride, pad)

    def group(x, params, base, stride, n):
        # Apply `n` residual blocks named `<base>.block<i>`. Only the first
        # block uses `stride`; the rest use stride 1.
        o = x
        for i in range(n):
            b_base = '%s.block%d.conv' % (base, i)
            shortcut = o
            o = conv2d(shortcut, params, b_base + '0', pad=1,
                       stride=stride if i == 0 else 1)
            o = F.relu(o, inplace=True)
            o = conv2d(o, params, b_base + '1', pad=1)
            if i == 0 and stride != 1:
                # Downsampling block: project the shortcut with a strided,
                # bias-free convolution (`<base>.block0.conv_dim.weight`).
                o += F.conv2d(shortcut, params[b_base + '_dim.weight'],
                              stride=stride)
            else:
                o += shortcut
            o = F.relu(o, inplace=True)
        return o

    def f(inputs, params, pr=''):
        # Re-wrap the input data; `volatile=True` is the legacy (pre-0.4)
        # PyTorch inference-only flag.
        inputs = Variable(inputs.data, volatile=True)
        # Stem: stride-2 / pad-3 convolution, ReLU, 3x3 stride-2 max-pool.
        o = conv2d(inputs, params, pr + 'conv0', 2, 3)
        o = F.relu(o, inplace=True)
        o = F.max_pool2d(o, 3, 2, 1)
        o_g0 = group(o, params, pr + 'group0', 1, blocks[0])
        o_g1 = group(o_g0, params, pr + 'group1', 2, blocks[1])
        o_g2 = group(o_g1, params, pr + 'group2', 2, blocks[2])
        o_g3 = group(o_g2, params, pr + 'group3', 2, blocks[3])
        # NOTE(review): the fixed 7x7 pooling assumes a 7x7 final feature map
        # (overall stride is 32, e.g. 224x224 inputs). Other input sizes likely
        # leave a spatial map that the flatten + linear below rejects.
        o = F.avg_pool2d(o_g3, 7, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params[pr + 'fc.weight'], params[pr + 'fc.bias'])
        # Re-wrap `.data` so returned values carry no autograd history.
        return (Variable(o.data),
                [Variable(v.data) for v in (o_g0, o_g1, o_g2, o_g3)])

    return f, params
评论列表
文章目录