def _format_ranked(class_indices, scores):
    """Render the given classes as tab-separated "index<TAB>score" fields.

    Fields are ordered by descending score; the sort is stable, so classes
    with equal scores keep the order in which ``class_indices`` lists them.
    """
    pairs = [(class_index, scores[class_index]) for class_index in class_indices]
    pairs = sorted(pairs, key=lambda p: -p[1])
    return "\t".join(["%d\t%f" % (x, y) for x, y in pairs])


def format_lines(video_ids, predictions, labels, top_k):
    """Yield one tab-separated report line per video in the batch.

    Each line has the form::

        <url> TAB <1 - perr> TAB <top-k class/score pairs> TAB <label class/score pairs> NEWLINE

    where every class/score pair is rendered as ``"%d\\t%f"`` and pairs are
    ordered by descending predicted score.

    Args:
        video_ids: Sequence of video ids, one per row. ``bytes`` ids are
            decoded as UTF-8; any other id is converted with ``str()``.
        predictions: Per-video sequence of per-class scores, indexed as
            ``predictions[video][class]``.
        labels: Per-video sequence of per-class label values, indexed the
            same way. Presumably a 0/1 multi-hot encoding -- confirm with the
            caller; the code only sums the row and indexes into it.
        top_k: Number of highest-scoring classes to report. Values larger
            than the number of classes are clamped to the number of classes.
            NOTE: ``top_k == 0`` is not special-cased; because the slice
            ``[-0:]`` selects everything, it reports every class.

    Yields:
        str: The formatted line for each video, terminated by a newline.
    """
    for video_index, video_id in enumerate(video_ids):
        scores = predictions[video_index]
        truth = labels[video_index]

        # Number of classes to compare against: the label total, but never
        # fewer than one so the division below is always defined.
        n_recall = max(int(numpy.sum(truth)), 1)

        # Labels: the n_recall classes with the largest label values, shown
        # with the score the model assigned to each of them.
        label_indices = numpy.argpartition(truth, -n_recall)[-n_recall:]
        label_str = _format_ranked(label_indices, scores)

        # Predictions: clamp top_k so argpartition does not raise when the
        # caller asks for more classes than exist.
        k = min(top_k, len(scores))
        top_k_indices = numpy.argpartition(scores, -k)[-k:]
        top_k_str = _format_ranked(top_k_indices, scores)

        # PERR: sum of the label values found among the n_recall
        # highest-scoring classes, divided by n_recall.
        top_n_indices = numpy.argpartition(scores, -n_recall)[-n_recall:]
        positives = [truth[class_index] for class_index in top_n_indices]
        perr = sum(positives) / float(n_recall)

        # URL: ids may arrive as raw bytes or as already-decoded text.
        if isinstance(video_id, bytes):
            video_id = video_id.decode('utf-8')
        else:
            video_id = str(video_id)
        url = "https://www.youtube.com/watch?v=" + video_id

        yield url + "\t" + str(1 - perr) + "\t" + top_k_str + "\t" + label_str + "\n"
inference-sample-error-analysis.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录