model_chainer.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目：jrm_ssl　作者：Fhrozen　项目源码　文件源码
def __init__(self, config, Network):
    """Build the network, optimizer and training options from *config*.

    Args:
        config: ConfigParser-style object. Keys read here:
            ``network/init_opt`` (';'-separated ints), ``gpu/use``,
            ``gpu/index`` (';'-separated GPU ids, only when ``gpu/use``),
            ``train/learning_rate``, ``data/labels`` (';'-separated),
            ``train/gaussian`` and ``train/eta_gn`` (only if gaussian),
            ``train/lasso`` and ``train/decay_lasso`` (only if lasso), and
            the optional recurrent-network keys ``data/sequence``,
            ``train/clip_threshold`` and ``train/use_clip``.
        Network: callable invoked with the parsed ``init_opt`` list to
            create the model stored in ``self.Networks[0]``.
    """
    # Parsed init options are passed straight to the Network constructor.
    init_opt = [int(x) for x in config.get('network', 'init_opt').split(';')]

    if config.getboolean('gpu', 'use'):
        list_gpu = config.get('gpu', 'index')
        print('Configuring the training for GPU calculation:')
        print('  using gpus: {}'.format(list_gpu))
        self.list_gpu = [int(x) for x in list_gpu.split(';')]
        # Select the first listed GPU before the model is created.
        chainer.cuda.get_device(self.list_gpu[0]).use()
        # NOTE(review): on GPU self.xp is the chainer.cuda module, while on
        # CPU it is numpy -- confirm how callers use self.xp.
        self.xp = chainer.cuda
        self.Networks = [Network(init_opt)]
    else:
        print('Configuring the training for CPU calculation:')
        self.xp = np
        self.list_gpu = []
        # The model must be created here as well; previously this branch
        # indexed self.Networks without ever assigning it (AttributeError).
        self.Networks = [Network(init_opt)]
        self.Networks[0].train = True

    # TODO: Set type of Optimizer on Config File
    self.Optimizer = optimizers.Adam(
        alpha=config.getfloat('train', 'learning_rate'))

    # Number of ';'-separated entries in data/labels.
    self._inputs = len(config.get('data', 'labels').split(';'))

    self._gaussian = config.getboolean('train', 'gaussian')
    if self._gaussian:
        self.eta = config.getfloat('train', 'eta_gn')

    self._lasso = config.getboolean('train', 'lasso')
    if self._lasso:
        self.lasso_dy = config.getfloat('train', 'decay_lasso')

    # These options are only set for a recurrent network; if any of them is
    # missing or malformed, sequential training is disabled.
    try:
        self.sequence = config.getint('data', 'sequence')
        self.clip_threshold = config.getfloat('train', 'clip_threshold')
        self._use_clip = config.getboolean('train', 'use_clip')
        self._lstm = True
        print('  Setting Network for Sequential Training...')
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # are no longer swallowed.
        self._use_clip = False
        self._lstm = False

    self.train = False
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号