nettrainer.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:deep-prior-pp 作者: moberweger 项目源码 文件源码
def __init__(self, cfgParams, memory_factor, subfolder='./eval/', numChunks=1):
        """
        Constructor

        Validates the configuration and derives a memory budget from the
        Theano device in use. It then initializes the trainer's bookkeeping
        state.

        :param cfgParams: initialized NetTrainerParams
        :param memory_factor: divisor applied to the memory reported for the
            current Theano device; the quotient (in MB) is stored in
            self.memorySize. Must be a positive number (anything float() accepts).
        :param subfolder: directory path stored in self.subfolder
        :param numChunks: number of data chunks; cfgParams.para_load may only be
            enabled when this is not 1
        :raises ValueError: if cfgParams is not a NetTrainerParams instance, if
            memory_factor is not a positive number, or if para_load is enabled
            while numChunks == 1
        :raises EnvironmentError: if theano.config.device contains neither
            'gpu' nor 'cpu'
        """

        # Validate all arguments before any instance state is touched, so a
        # rejected call does not leave a half-initialized object behind.
        if not isinstance(cfgParams, NetTrainerParams):
            raise ValueError("cfgParams must be an instance of NetTrainerParams")

        factor = float(memory_factor)
        # `not factor > 0` also rejects NaN, which `factor <= 0` would let through.
        if not factor > 0:
            raise ValueError("memory_factor must be a positive number, got {!r}".format(memory_factor))

        # Truthiness instead of `is True`, so numpy.bool_(True) or 1 are also caught.
        if cfgParams.para_load and numChunks == 1:
            raise ValueError("para_load is True but numChunks == 1, so we do not need para_load!")

        self.subfolder = subfolder
        self.cfgParams = cfgParams
        self.rng = numpy.random.RandomState(23455)  # fixed seed -> reproducible runs

        # use fraction of free memory
        device = theano.config.device
        if 'gpu' in device:
            # Legacy theano.sandbox.cuda backend; mem_info()[0] is presumably the
            # free device memory in bytes -- verify against the installed Theano.
            mem_info = theano.sandbox.cuda.cuda_ndarray.cuda_ndarray.mem_info()
            self.memorySize = (mem_info[0] / 1024 ** 2) / factor  # MB
        elif 'cpu' in device:
            # psutil reports memory available to new processes, in bytes.
            self.memorySize = (psutil.virtual_memory().available / 1024 ** 2) / factor  # MB
        else:
            raise EnvironmentError("Neither GPU nor CPU device in theano.config found!")

        self.currentMacroBatch = -1  # current batch on GPU, load on first run
        self.currentChunk = -1  # current chunk in RAM, load on first run
        self.numChunks = numChunks
        # Sizes/counters below start at zero; they are filled in elsewhere in the class.
        self.trainSize = 0
        self.sampleSize = 0
        self.numTrainSamplesMB = 0
        self.numTrainSamples = 0
        self.numValSamples = 0
        self.epoch = 0
        self.managedVar = []
        self.trainingVar = []
        self.validation_observer = []
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号