def paramscheck(self):
    """Validate the option set and fill in the values derived from it.

    Reads ``netType``, ``data_set``, ``experimentID``, ``useDefaultSetting``,
    ``depth`` (only when ``netType`` is not ``"LeNet"``) and ``nEpochs``
    (only when ``useDefaultSetting`` is truthy).

    Writes:
        drawNetwork: forced to ``False`` unless the installed PyTorch
            reports version ``0.1.10``.
        save_path: log directory name built from the network type, the data
            set, the depth (omitted for LeNet) and the experiment ID.
        LR, lrPolicy, momentum, weightDecay: only when ``useDefaultSetting``
            is truthy; the 160-epoch schedule additionally sets ``step`` and
            ``gamma``.
        nClasses: 10 for cifar10/mnist, 100 for cifar100; left untouched for
            any other data set.

    Raises:
        AssertionError: if ``useDefaultSetting`` is truthy and ``data_set``
            is neither ``"cifar10"`` nor ``"cifar100"``.
    """
    torch_version = torch.__version__
    # Only the part before the first "_" is compared (presumably so that
    # suffixed builds such as "0.1.10_2" still match -- TODO confirm).
    if torch_version.split("_")[0] != "0.1.10":
        self.drawNetwork = False
        # A single pre-formatted string prints identically whether ``print``
        # is the Python 2 statement or the Python 3 function. The two spaces
        # after the colon reproduce what the old ``print a, b`` form emitted.
        print("|===>DrawNetwork is unsupported by PyTorch with version:  %s"
              % (torch_version,))

    # LeNet log directories carry no depth component.
    if self.netType == "LeNet":
        self.save_path = "log_%s_%s_%s/" % (
            self.netType, self.data_set, self.experimentID)
    else:
        self.save_path = "log_%s_%s_%d_%s/" % (
            self.netType, self.data_set, self.depth, self.experimentID)

    if self.useDefaultSetting:
        print("|===> Use Default Setting")
        if self.data_set not in ("cifar10", "cifar100"):
            # Raised explicitly instead of ``assert False`` so the check is
            # not stripped under ``python -O``; the exception type and the
            # message are unchanged.
            raise AssertionError("invalid data set")

        if self.nEpochs == 160:
            self.LR = 0.5
            self.lrPolicy = "exp"
            self.momentum = 0.9
            self.weightDecay = 1e-4
            self.step = 2.0
            # gamma is chosen so that
            # LR * gamma ** floor(nEpochs / step) == 0.001.
            n_decays = math.floor(self.nEpochs / self.step)
            self.gamma = math.pow(0.001 / self.LR, 1.0 / n_decays)
        else:
            self.LR = 0.1
            self.lrPolicy = "multistep"
            self.momentum = 0.9
            self.weightDecay = 1e-4

    if self.data_set in ("cifar10", "mnist"):
        self.nClasses = 10
    elif self.data_set == "cifar100":
        self.nClasses = 100
# Web-page residue (not code): "Comment list"
# Web-page residue (not code): "Table of contents"