def mobilenet(depth, width, depthwise_function, cuda=True):
    """Build a functional MobileNet-style feature extractor.

    The network is a 3x3 stride-2 stem convolution with ReLU, followed by
    13 blocks of (3x3 ``depthwise_function`` conv -> 1x1 conv -> ReLU).
    No bias, batch normalization, pooling or classifier head is created;
    ``f`` returns the last feature map.

    Args:
        depth: Not read by this function.
            NOTE(review): presumably kept so the signature matches other
            model builders in the project -- confirm with callers.
        width: Not read by this function; channel counts come from the
            fixed ``cfg`` list below. NOTE(review): same assumption as
            ``depth``.
        depthwise_function: Callable invoked per block as
            ``depthwise_function(o, weight, stride=stride, padding=1)``,
            where ``weight`` has shape ``(channels, 1, 3, 3)``.
        cuda: If True (the default, matching the original behaviour), every
            parameter tensor is moved to the GPU with ``.cuda()``. Pass
            False to keep the parameters on the CPU.

    Returns:
        A ``(f, params)`` tuple. ``f(input, params)`` runs the network on a
        3-channel input. ``params`` maps ``'conv0'``, ``'block%d.conv0'`` and
        ``'block%d.conv1'`` to ``Variable`` objects with
        ``requires_grad=True``.
    """
    # Each entry is either an output channel count (stride 1) or an
    # (output channels, stride) pair.
    cfg = [64, (128, 2), 128, (256, 2), 256, (512, 2),
           512, 512, 512, 512, 512, (1024, 2), 1024]

    def cast(x):
        # Device placement is now selectable instead of hard-coded to CUDA.
        return x.cuda() if cuda else x

    ni = 32  # output channels of the stem convolution
    params = {'conv0': cast(kaiming_normal(torch.Tensor(ni, 3, 3, 3)))}
    for i, x in enumerate(cfg):
        no = x if isinstance(x, int) else x[0]
        # Shape (ni, 1, 3, 3): presumably one 3x3 filter per channel, as a
        # depthwise conv expects -- depends on depthwise_function.
        params['block%d.conv0' % i] = cast(kaiming_normal(torch.Tensor(ni, 1, 3, 3)))
        # 1x1 kernel mapping ni input channels to no output channels.
        params['block%d.conv1' % i] = cast(kaiming_normal(torch.Tensor(no, ni, 1, 1)))
        ni = no
    # Wrap only after device placement so the Variables are autograd leaves.
    params = {k: Variable(v, requires_grad=True) for k, v in params.items()}

    def f(input, params):
        """Run the network on ``input`` using the weights in ``params``."""
        o = F.conv2d(input, params['conv0'], padding=1, stride=2)
        o = F.relu(o, inplace=True)
        for i, x in enumerate(cfg):
            stride = 1 if isinstance(x, int) else x[1]
            o = depthwise_function(o, params['block%d.conv0' % i],
                                   stride=stride, padding=1)
            o = F.conv2d(o, params['block%d.conv1' % i])
            o = F.relu(o, inplace=True)
        return o

    return f, params
评论列表
文章目录