benchmark.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:pyinn 作者: szagoruyko 项目源码 文件源码
def mobilenet(depth, width, depthwise_function, cuda=True):
    """Build a functional MobileNet-style network for benchmarking.

    Structure of the network:

    - A 3x3 stride-2 stem convolution from 3 to 32 channels, followed by ReLU.
    - 13 blocks. Each block is a 3x3 depthwise convolution (applied through
      ``depthwise_function``), then a 1x1 pointwise convolution, then ReLU.

    Args:
        depth: Unused by this function; accepted for call-site compatibility.
        width: Unused by this function; accepted for call-site compatibility.
        depthwise_function: Callable invoked as
            ``depthwise_function(input, weight, stride=..., padding=1)``.
            The weight has shape ``(channels, 1, 3, 3)``. It is presumably
            meant to be applied with one group per channel (depthwise);
            confirm against the implementation being benchmarked.
        cuda: If true (the default, matching the original hard-coded
            behaviour), every parameter is moved to the GPU with ``.cuda()``.
            Pass ``False`` to keep the parameters on the CPU.

    Returns:
        A ``(f, params)`` pair:

        - ``params`` maps ``'conv0'``, ``'block%d.conv0'`` (depthwise) and
          ``'block%d.conv1'`` (pointwise) to tensors with ``requires_grad=True``.
        - ``f(input, params)`` runs the forward pass. ``input`` must have
          3 channels to match the stem weight.
    """
    # Each entry gives a block's output channels. A bare int means stride 1;
    # a (channels, stride) tuple marks a downsampling block.
    cfg = [64, (128, 2), 128, (256, 2), 256, (512, 2), 512, 512, 512, 512, 512, (1024, 2), 1024]

    def cast(x):
        # Only move to the GPU when requested, so the model can be built on CPU-only hosts.
        return x.cuda() if cuda else x

    ni = 32
    params = {'conv0': cast(kaiming_normal(torch.Tensor(ni, 3, 3, 3)))}

    for i, x in enumerate(cfg):
        no = x if isinstance(x, int) else x[0]
        # Depthwise filter: a single-input-channel 3x3 kernel for each of the ni channels.
        params['block%d.conv0' % i] = cast(kaiming_normal(torch.Tensor(ni, 1, 3, 3)))
        # Pointwise 1x1 projection from ni to no channels.
        params['block%d.conv1' % i] = cast(kaiming_normal(torch.Tensor(no, ni, 1, 1)))
        ni = no

    params = {k: Variable(v, requires_grad=True) for k, v in params.items()}

    def f(input, params):
        o = F.conv2d(input, params['conv0'], padding=1, stride=2)
        o = F.relu(o, inplace=True)
        for i, x in enumerate(cfg):
            stride = 1 if isinstance(x, int) else x[1]
            o = depthwise_function(o, params['block%d.conv0' % i], stride=stride, padding=1)
            o = F.conv2d(o, params['block%d.conv1' % i])
            # In-place ReLU avoids allocating another activation tensor per block.
            o = F.relu(o, inplace=True)
        return o

    return f, params
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号