opt.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:ExperimentPackage_PyTorch 作者: ICEORY 项目源码 文件源码
def paramscheck(self):
    """Validate and derive experiment options in place.

    Reads option attributes already present on ``self`` (``netType``,
    ``data_set``, ``experimentID``, ``depth`` for non-LeNet networks,
    ``useDefaultSetting`` and ``nEpochs``) and writes derived ones:

    * ``drawNetwork`` is forced to ``False`` (with a printed notice) unless
      the installed PyTorch base version is ``0.1.10``.
    * ``save_path`` is built from the network type, data set, depth (omitted
      for LeNet) and experiment id.
    * When ``useDefaultSetting`` is true, ``LR``, ``lrPolicy``, ``momentum``
      and ``weightDecay`` are overwritten; for ``nEpochs == 160`` the "exp"
      policy is selected and ``step`` and ``gamma`` are set as well.
    * ``nClasses`` is set for cifar10, cifar100 and mnist; for any other
      data set it is left untouched.

    Raises:
        AssertionError: if ``useDefaultSetting`` is true and ``data_set`` is
            neither "cifar10" nor "cifar100".
    """
    torch_version = torch.__version__
    # Version strings may carry a build suffix (e.g. "0.1.12_2" or
    # "0.1.10+ac9245a"); compare only the base release number.
    base_version = torch_version.split("+")[0].split("_")[0]
    if base_version != "0.1.10":
        self.drawNetwork = False
        # Single-argument print: identical output under Python 2 and 3.
        print("|===>DrawNetwork is unsupported by PyTorch with version: %s"
              % torch_version)

    if self.netType == "LeNet":
        # For LeNet the depth is not part of the directory name.
        self.save_path = "log_%s_%s_%s/" % (
            self.netType, self.data_set, self.experimentID)
    else:
        self.save_path = "log_%s_%s_%d_%s/" % (
            self.netType, self.data_set, self.depth, self.experimentID)

    if self.useDefaultSetting:
        print("|===> Use Default Setting")
        if self.data_set not in ("cifar10", "cifar100"):
            # Explicit raise rather than ``assert`` so the check is not
            # stripped when Python runs with -O.
            raise AssertionError("invalid data set")
        self.momentum = 0.9
        self.weightDecay = 1e-4
        if self.nEpochs == 160:
            self.LR = 0.5
            self.lrPolicy = "exp"
            self.step = 2.0
            # gamma is chosen so that LR * gamma ** k == 0.001,
            # with k = floor(nEpochs / step).
            self.gamma = math.pow(0.001 / self.LR,
                                  1.0 / math.floor(self.nEpochs / self.step))
        else:
            self.LR = 0.1
            self.lrPolicy = "multistep"

    n_classes = {"cifar10": 10, "mnist": 10, "cifar100": 100}.get(self.data_set)
    if n_classes is not None:
        self.nClasses = n_classes
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号