def resnet_orth_v2(n=3):
    """Emit the solver and network definitions for a "resnet-orth-v2" model.

    The depth suffix is 6n+2, so n = 3, 9, 18 corresponds to 20, 56, 110
    layers. All output goes to a folder named ``resnet-orth-v2<6n+2>``
    located under the current working directory.

    Args:
        n: Size parameter forwarded to ``Net.resnet_cifar``; it also
            determines the ``6n+2`` suffix used in the folder and net names.
    """
    # Shared "<prefix><depth>" tag used for both the folder and the net name.
    tag = "resnet-orth-v2" + str(6 * n + 2)
    out_dir = osp.join(osp.abspath(osp.curdir), tag)

    if n > 18:
        # Warm-up solver: fixed learning-rate policy, base_lr 0.01, capped at
        # 500 iterations.
        # NOTE(review): n == 18 (110 layers) does not take this branch --
        # confirm that excluding it is intended.
        warm_solver = Solver(
            solver_name="solver_warm.prototxt",
            folder=out_dir,
            lr_policy=Solver.policy.fixed,
        )
        warm_solver.p.base_lr = 0.01
        warm_solver.set_max_iter(500)
        warm_solver.write()
        del warm_solver

    # Default solver, always written (presumably with Solver's own default
    # file name -- verify against the Solver class).
    main_solver = Solver(folder=out_dir)
    main_solver.write()
    del main_solver

    # Network definition: one Data layer per phase, then the residual body.
    net = Net(tag + "-cifar10")
    net.Data("cifar-10-batches-py/train", phase="TRAIN", crop_size=32)
    net.Data("cifar-10-batches-py/test", phase="TEST")
    net.resnet_cifar(n, orth=True, v2=True)
    net.write(folder=out_dir)
评论列表
文章目录