def generate_train_proto(self, model_fn, fts_lmdb, sem_lmdb, batch_size):
    """Build the training/evaluation network definition and write it to ``model_fn``.

    The generated net contains:

    * one LMDB ``Data`` layer per subset: ``'train'`` for the TRAIN phase,
      and ``'testRecg'`` / ``'testZS'`` for the TEST phase under the
      ``TestRecognition`` / ``TestZeroShot`` stages;
    * an optional semantic-label ``Data`` layer (TRAIN only, when
      ``self.sem_coeff > 0``);
    * the base CNN truncated at ``self.feat_layer``;
    * the score heads, the loss layers and the evaluation layers.

    Args:
        model_fn: Output path; the text form of ``ns.to_proto()`` is written
            here, overwriting any existing file.
        fts_lmdb: Mapping from subset name to the LMDB source of the
            ``data``/``labels`` layer. The keys ``'train'``, ``'testRecg'``
            and ``'testZS'`` are all read.
        sem_lmdb: Mapping whose ``'train'`` entry is the LMDB source for the
            semantic labels. It is only read when ``self.sem_coeff > 0``.
        batch_size: Batch size used by every ``Data`` layer.

    Side effects:
        Sets ``self.scores`` to ``{'obj': ..., 'semantics': ...}`` as returned
        by ``self._score_proto``.
    """
    ns = self._new_model()

    # ---- Inputs ----
    # Per-channel mean subtracted by each Data layer's transform.
    mean = [104., 116., 122.]
    # Caffe stage names that select which TEST-phase data layer is active.
    stage = {'testRecg': 'TestRecognition',
             'testZS': 'TestZeroShot'}
    for subset in ['train', 'testRecg', 'testZS']:
        if subset == 'train':
            include = {'phase': caffe.TRAIN}
        else:
            include = {'phase': caffe.TEST, 'stage': stage[subset]}
        # All three data layers share the layer name 'data' and the explicit
        # top names 'data'/'labels'.
        # NOTE(review): in_place=True together with an explicit `top` list
        # appears to be the NetSpec workaround that makes caffe emit these
        # literal top names instead of auto-generated ones. This lets layers
        # built from ns.train_data / ns.train_labels below consume the same
        # blobs in every phase/stage. Confirm against caffe.net_spec.
        ns[subset + '_data'], ns[subset + '_labels'] = L.Data(
            name='data', ntop=2, top=['data', 'labels'], in_place=True,
            source=fts_lmdb[subset], batch_size=batch_size, backend=P.Data.LMDB,
            # Mirror augmentation only for the training subset.
            transform_param=dict(mirror=(subset == 'train'),
                                 crop_size=self.base_cnn.input_size,
                                 mean_value=mean),
            include=include)

    # ---- Semantic labels (training only) ----
    # Evaluate the flag once, so the layer creation here and the loss wiring
    # below always agree on whether ns.semantics exists.
    use_semantics = self.sem_coeff > 0
    if use_semantics:
        ns.semantics = L.Data(name='semantics',
                              source=sem_lmdb['train'], batch_size=batch_size,
                              backend=P.Data.LMDB,
                              include=dict(phase=caffe.TRAIN))

    # ---- Base CNN, truncated at the configured feature layer ----
    xFt = self.base_cnn.inference_proto(ns.train_data, mult=1., truncate_at=self.feat_layer)

    # ---- Score heads ----
    # target_net is enabled only when self.test_classes is set.
    xObj, xSem, lCW = self._score_proto(xFt, source_net=True,
                                        target_net=self.test_classes is not None,
                                        mult=1.0)
    self.scores = {'obj': xObj, 'semantics': xSem}

    # ---- Loss ----
    # NOTE(review): xObj / xSem are used as keys into ns. Presumably
    # _score_proto registers those tops on this same NetSpec; confirm.
    self._loss_proto(ns[xObj], ns.train_labels, ns[xSem],
                     ns.semantics if use_semantics else None, lCW)

    # ---- Evaluation ----
    self._eval_proto(ns[xObj], ns.train_labels)

    with open(model_fn, 'w') as f:
        f.write(str(ns.to_proto()))
评论列表
文章目录