imagenet.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:attention-transfer 作者: szagoruyko 项目源码 文件源码
def define_teacher(params_file):
    """Define the pretrained teacher ResNet as a functional forward pass.

    Loads the teacher's parameters and infers the network depth from the
    parameter names. The number of blocks in group ``j`` is the number of
    keys matching ``group<j>.block<i>.conv0.weight``. This assumes a
    basic-block ResNet (e.g. ResNet-18 or ResNet-34) in which each block
    computes ``relu(conv1(relu(conv0(x))) + shortcut(x))``. There is no
    separate normalization layer: every convolution except the shortcut
    projection uses its own ``.bias``. (Presumably batch norm was folded
    into the conv weights/biases during conversion -- TODO confirm.)

    Args:
        params_file: path to a file readable by ``hkl.load`` whose values
            are numpy arrays keyed by parameter name.

    Returns:
        A tuple ``(f, params)``.
        ``params`` is an ``OrderedDict`` of CUDA ``Variable`` s built from
        the file, in the file's key order.
        ``f(inputs, params, pr='')`` runs the forward pass and returns
        ``(fc_output, [g0, g1, g2, g3])``, i.e. the final linear output and
        the four group outputs. All of them are re-wrapped so they carry
        no autograd history.
    """
    params_hkl = hkl.load(params_file)

    # Build from (key, value) pairs so the loader's key order is kept even
    # on Pythons where a plain dict would not preserve insertion order.
    params = OrderedDict((k, Variable(torch.from_numpy(v).cuda()))
                         for k, v in params_hkl.items())

    # Count blocks per group from the (unprefixed) key names. Use a raw
    # string so ``\d`` is a regex digit class rather than an invalid string
    # escape, and escape the dots so they only match literal separators.
    blocks = [sum(re.match(r'group%d\.block\d+\.conv0\.weight' % j, k) is not None
                  for k in params)
              for j in range(4)]

    def conv2d(input, params, base, stride=1, pad=0):
        # Convolution whose weight/bias live at '<base>.weight' / '<base>.bias'.
        return F.conv2d(input, params[base + '.weight'], params[base + '.bias'],
                        stride, pad)

    def group(input, params, base, stride, n):
        # Apply n basic blocks. Only the first block uses `stride`;
        # the remaining blocks run at stride 1.
        o = input
        for i in range(n):
            b_base = '%s.block%d.conv' % (base, i)
            x = o
            block_stride = stride if i == 0 else 1
            o = conv2d(x, params, b_base + '0', pad=1, stride=block_stride)
            o = F.relu(o, inplace=True)
            o = conv2d(o, params, b_base + '1', pad=1)
            if i == 0 and stride != 1:
                # Downsampling block: project the shortcut with the
                # bias-free '<b_base>_dim' convolution at the same stride,
                # so it can be added to the strided block output.
                o += F.conv2d(x, params[b_base + '_dim.weight'], stride=stride)
            else:
                o += x
            o = F.relu(o, inplace=True)
        return o

    def f(inputs, params, pr=''):
        # `pr` is prepended to every key looked up here. `blocks`, however,
        # was counted from the unprefixed keys loaded above.
        # Re-wrap the input as a volatile Variable. In pre-0.4 PyTorch this
        # means no autograd graph is recorded for the teacher. In 0.4+
        # `volatile` is ignored with a warning, and `torch.no_grad()`
        # is the replacement.
        inputs = Variable(inputs.data, volatile=True)
        o = conv2d(inputs, params, pr + 'conv0', 2, 3)
        o = F.relu(o, inplace=True)
        o = F.max_pool2d(o, 3, 2, 1)
        o_g0 = group(o, params, pr + 'group0', 1, blocks[0])
        o_g1 = group(o_g0, params, pr + 'group1', 2, blocks[1])
        o_g2 = group(o_g1, params, pr + 'group2', 2, blocks[2])
        o_g3 = group(o_g2, params, pr + 'group3', 2, blocks[3])
        # The average pool is fixed at 7x7. With the 32x total downsampling
        # above (conv0, max-pool, groups 1-3), this expects a 7x7 final map,
        # e.g. from 224x224 inputs. Other sizes change the flattened width
        # and presumably would not match fc.weight.
        o = F.avg_pool2d(o_g3, 7, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params[pr + 'fc.weight'], params[pr + 'fc.bias'])
        # Re-wrap outputs from `.data` so callers get fresh Variables with
        # no autograd history leading back into this network.
        return Variable(o.data), [Variable(v.data) for v in [o_g0, o_g1, o_g2, o_g3]]

    return f, params
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号