nin_helper.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:hyperband_benchmarks 作者: lishal 项目源码 文件源码
def run_solver(self, unit, n_units, arm, disp_interval=100):
    """Train one arm for a given budget, snapshot it, and evaluate it.

    The solver described by ``arm['solver_file']`` is created on GPU
    ``self.device``. If the arm has already been trained
    (``arm['n_iter'] > 0``) the solver state and weights are restored from
    the snapshot files under ``arm['dir']`` before training continues.

    Args:
        unit: Budget unit. ``'time'`` trains one iteration at a time until
            ``n_units`` seconds have elapsed; ``'iter'`` trains for
            ``n_units`` iterations, capped so the arm's total does not
            exceed the iteration limit. Any other value skips training,
            but the snapshot and evaluation still run.
        n_units: Size of the budget, in seconds or iterations.
        arm: Dict with at least the keys ``'solver_file'``, ``'n_iter'``
            and ``'dir'``. ``arm['n_iter']`` is updated in place with the
            number of iterations performed.
        disp_interval: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(train_loss, val_acc, test_acc)``: the training net's
        ``'loss'`` blob data after the last step, and the ``'acc'`` blob
        averaged over 100 forward passes of ``test_nets[0]`` and
        ``test_nets[1]`` respectively.
    """
    # Upper bound on an arm's total iterations in 'iter' mode.
    # NOTE(review): presumably 150 epochs of 400 iterations -- confirm.
    max_iter = 400 * 150
    val_batches = 100
    test_batches = 100

    def _mean_acc(net, n_batches):
        # Average the net's 'acc' blob over n_batches forward passes.
        total = 0
        for _ in range(n_batches):
            net.forward()
            total += net.blobs['acc'].data
        return total / n_batches

    caffe.set_device(self.device)
    caffe.set_mode_gpu()
    s = caffe.get_solver(arm['solver_file'])

    if arm['n_iter'] > 0:
        # Resume from the snapshot written at the arm's last iteration.
        snapshot_base = (arm['dir'] + "/" + str(self.problem)
                         + "_data_iter_" + str(arm['n_iter']))
        s.restore(snapshot_base + ".solverstate")
        s.net.copy_from(snapshot_base + ".caffemodel")
        s.test_nets[0].share_with(s.net)
        s.test_nets[1].share_with(s.net)

    if unit == 'time':
        start = time.time()
        while time.time() - start < n_units:
            s.step(1)
            arm['n_iter'] += 1
    elif unit == 'iter':
        # Clamp at zero: if the arm is already past the cap the remaining
        # budget would be negative, and adding it to arm['n_iter'] would
        # rewind the counter out of sync with the solver's own iteration.
        n_units = max(0, min(n_units, max_iter - arm['n_iter']))
        s.step(n_units)
        arm['n_iter'] += n_units

    s.snapshot()
    train_loss = s.net.blobs['loss'].data

    val_acc = _mean_acc(s.test_nets[0], val_batches)
    test_acc = _mean_acc(s.test_nets[1], test_batches)

    # Drop the local solver reference before returning
    # (presumably to let its resources be reclaimed -- confirm).
    del s
    return train_loss, val_acc, test_acc
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号