score_model.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:score-zeroshot 作者: pedro-morgado 项目源码 文件源码
def generate_train_proto(self, model_fn, fts_lmdb, sem_lmdb, batch_size):
        """Assemble the train/test Caffe network and write its prototxt to disk.

        One LMDB ``Data`` layer is declared for each subset ('train',
        'testRecg', 'testZS'). Every one of them is given the explicit top
        names ``data`` and ``labels``, so the layers built afterwards are
        wired only once, against the training tops. The train layer is
        restricted to the TRAIN phase. The two test layers are restricted to
        the TEST phase with stages 'TestRecognition' / 'TestZeroShot'.

        When ``self.sem_coeff > 0`` an additional TRAIN-only ``semantics``
        Data layer is declared. The base CNN, the score head, the loss layers
        and the evaluation layers are then appended. The resulting net is
        serialized to ``model_fn``.

        Args:
            model_fn: path of the prototxt file to write (opened with 'w').
            fts_lmdb: mapping from subset name ('train', 'testRecg',
                'testZS') to the LMDB source of that subset's data/labels.
            sem_lmdb: mapping whose 'train' entry is the LMDB source of the
                semantic targets; only read when ``self.sem_coeff > 0``.
            batch_size: batch size shared by every Data layer.

        Side effects:
            Sets ``self.scores`` to ``{'obj': ..., 'semantics': ...}`` built
            from the first two values returned by ``self._score_proto``.
        """
        net = self._new_model()

        # --- Input layers ---------------------------------------------------
        # Per-channel mean values handed to the data transformer.
        pixel_mean = [104., 116., 122.]
        test_stages = {'testRecg': 'TestRecognition', 'testZS': 'TestZeroShot'}
        for split in ('train', 'testRecg', 'testZS'):
            is_train = split == 'train'
            if is_train:
                rule = {'phase': caffe.TRAIN}
            else:
                rule = {'phase': caffe.TEST, 'stage': test_stages[split]}
            # NOTE(review): in_place=True together with an explicit `top`
            # list appears to be used so every subset publishes the same
            # 'data'/'labels' blob names instead of NetSpec's auto-generated
            # ones. Confirm against caffe's net_spec before changing it.
            data_top, label_top = L.Data(
                name='data', ntop=2, top=['data', 'labels'], in_place=True,
                source=fts_lmdb[split], batch_size=batch_size, backend=P.Data.LMDB,
                transform_param=dict(mirror=is_train,  # random flips only while training
                                     crop_size=self.base_cnn.input_size,
                                     mean_value=pixel_mean),
                include=rule)
            net[split + '_data'] = data_top
            net[split + '_labels'] = label_top

        # Semantic targets are only fed to the net when the semantic term is on.
        if self.sem_coeff > 0:
            net.semantics = L.Data(name='semantics',
                                   source=sem_lmdb['train'],
                                   batch_size=batch_size,
                                   backend=P.Data.LMDB,
                                   include=dict(phase=caffe.TRAIN))

        # --- Backbone -------------------------------------------------------
        features = self.base_cnn.inference_proto(net.train_data, mult=1., truncate_at=self.feat_layer)

        # --- Score head -----------------------------------------------------
        # A target (test-class) branch is requested only when test classes exist.
        obj_name, sem_name, cls_weights = self._score_proto(
            features,
            source_net=True,
            target_net=self.test_classes is not None,
            mult=1.0)
        self.scores = {'obj': obj_name, 'semantics': sem_name}

        # --- Loss and evaluation -------------------------------------------
        self._loss_proto(net[obj_name],
                         net.train_labels,
                         net[sem_name],
                         net.semantics if self.sem_coeff > 0 else None,
                         cls_weights)
        self._eval_proto(net[obj_name], net.train_labels)

        with open(model_fn, 'w') as out_file:
            out_file.write(str(net.to_proto()))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号