def decode_predictions(preds, top=5):
LABELS = None
if len(preds.shape) == 2:
if preds.shape[1] == 2622:
fpath = get_file('rcmalli_vggface_labels_v1.npy',
V1_LABELS_PATH,
cache_subdir=VGGFACE_DIR)
LABELS = np.load(fpath)
elif preds.shape[1] == 8631:
fpath = get_file('rcmalli_vggface_labels_v2.npy',
V2_LABELS_PATH,
cache_subdir=VGGFACE_DIR)
LABELS = np.load(fpath)
else:
raise ValueError('`decode_predictions` expects '
'a batch of predictions '
'(i.e. a 2D array of shape (samples, 2622)) for V1 or '
'(samples, 8631) for V2.'
'Found array with shape: ' + str(preds.shape))
else:
raise ValueError('`decode_predictions` expects '
'a batch of predictions '
'(i.e. a 2D array of shape (samples, 2622)) for V1 or '
'(samples, 8631) for V2.'
'Found array with shape: ' + str(preds.shape))
results = []
for pred in preds:
top_indices = pred.argsort()[-top:][::-1]
result = [[str(LABELS[i].encode('utf8')), pred[i]] for i in top_indices]
result.sort(key=lambda x: x[1], reverse=True)
results.append(result)
return results
评论列表
文章目录