nettrainer.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:deep-prior 作者: moberweger 项目源码 文件源码
def __init__(self, cfgParams, memory_factor):
        """
        Constructor

        Checks the arguments, derives the memory budget (in MB) from the
        memory currently free on the configured Theano device and resets the
        bookkeeping attributes.

        :param cfgParams: initialized NetTrainerParams
        :param memory_factor: positive divisor applied to the free memory; the
                              quotient is the budget stored in self.memory
                              (e.g. 3 budgets a third of the free memory)
        :raises ValueError: if cfgParams is not a NetTrainerParams instance or
                            memory_factor is not greater than zero
        :raises EnvironmentError: if theano.config.device names neither a GPU
                                  nor the CPU
        """

        # Validate before touching self so a failed construction leaves no
        # half-initialised state behind.
        if not isinstance(cfgParams, NetTrainerParams):
            raise ValueError("cfgParams must be an instance of NetTrainerParams")

        divisor = float(memory_factor)
        if divisor <= 0:
            # Zero would surface as a bare ZeroDivisionError and a negative
            # value would silently produce a negative memory budget.
            raise ValueError("memory_factor must be greater than zero")

        self.cfgParams = cfgParams

        device = theano.config.device
        if 'gpu' in device:
            # get GPU memory info; the first entry is used as the free byte count
            free_bytes = theano.sandbox.cuda.cuda_ndarray.cuda_ndarray.mem_info()[0]
        elif 'cpu' in device:
            # get CPU memory info
            free_bytes = psutil.virtual_memory().available
        else:
            raise EnvironmentError("Neither GPU nor CPU device in theano.config found!")

        # MB, use 1/memory_factor of the free memory
        self.memory = (free_bytes / 1024 ** 2) / divisor

        self.currentMacroBatch = -1  # current batch on GPU, load on first run
        self.trainSize = 0
        self.sampleSize = 0
        self.numTrainSamples = 0
        self.numValSamples = 0
        self.managedVar = []
评论列表


问题


面经


文章

微信
公众号

扫码关注公众号