def add_multilabel_data_layer(net, name, phase, num_classes, class_list=None):
    """Add a ``MultiLabelData`` Python data layer to a Caffe net spec.

    The layer is the Python class ``MultiLabelData`` from the module
    ``layers.multilabel_data``. Its configuration is passed through
    ``param_str`` as a YAML-serialized dict with the keys ``num_classes``,
    ``stage`` (when the phase is recognized) and, optionally, ``class_list``.

    Args:
        net: Net-spec object that supports item assignment; the layer's two
            tops are stored on it under ``name[0]`` and ``name[1]``.
        name: Two-element sequence. ``name[0]`` is used both as the layer name
            and as the key of the first top; ``name[1]`` is the key of the
            second top.
        phase: ``caffe.TRAIN`` or ``caffe.TEST``. Used as the layer's
            ``include`` phase, and mapped to the ``stage`` entry
            (``'TRAIN'`` or ``'VAL'``) of ``param_str``.
        num_classes: Number of classes, forwarded in ``param_str``.
        class_list: Optional iterable of class entries. It must contain exactly
            ``num_classes`` items and is forwarded in ``param_str`` as a plain
            list.

    Returns:
        None. The function works by assigning the two tops onto ``net``.

    Raises:
        ValueError: If ``class_list`` is given and its length differs from
            ``num_classes``.
    """
    include_dict = {'phase': phase}
    param = {'num_classes': num_classes}
    if phase == caffe.TRAIN:
        param['stage'] = 'TRAIN'
    elif phase == caffe.TEST:
        # The test phase is reported to the data layer as the 'VAL' stage.
        param['stage'] = 'VAL'
    # NOTE(review): any other phase value leaves 'stage' unset; confirm the
    # MultiLabelData layer tolerates a missing stage, or reject it here.
    if class_list is not None:
        # Materialize as a plain list. This accepts any iterable, and it avoids
        # yaml.dump emitting a python-specific '!!python/tuple' tag for tuples,
        # which yaml.safe_load cannot read back.
        class_list = list(class_list)
        if len(class_list) != num_classes:
            # Explicit raise rather than assert: asserts are stripped under -O.
            raise ValueError(
                'Length of class list does not match number of classes {} vs {}'
                .format(len(class_list), num_classes))
        param['class_list'] = class_list
    param_str = yaml.dump(param)
    net[name[0]], net[name[1]] = L.Python(
        name=name[0],
        python_param=dict(module='layers.multilabel_data',
                          layer='MultiLabelData',
                          param_str=param_str),
        include=include_dict,
        ntop=2)
评论列表
文章目录