inference-sample-error-analysis.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:youtube-8m 作者: wangheda 项目源码 文件源码
def _format_scored_classes(scores, class_indices):
  """Render classes as tab-separated "index<TAB>score" pairs, highest score first."""
  scored = [(class_index, scores[class_index]) for class_index in class_indices]
  # Stable sort on the negated score keeps the incoming order among ties.
  scored.sort(key=lambda pair: -pair[1])
  return "\t".join("%d\t%f" % (class_index, score)
                   for class_index, score in scored)


def format_lines(video_ids, predictions, labels, top_k):
  """Yield one tab-separated error-analysis line per video in the batch.

  Each yielded line has the form

    url <TAB> 1 - perr <TAB> top-k (class, score) pairs <TAB> label (class, score) pairs <NEWLINE>

  where ``perr`` is the fraction of the ``n_recall`` highest-scored classes
  that are marked in ``labels``, and ``n_recall`` is the label sum for the
  video (at least 1).

  Args:
    video_ids: sequence of video ids, one per video; ``str`` values are used
      as-is, anything else is decoded as UTF-8.
    predictions: per-video sequence of class scores, indexed
      ``predictions[video][class]``.
    labels: per-video sequence of label values, indexed like ``predictions``.
    top_k: number of highest-scored classes to list. It is clamped to
      ``[0, number of classes]``, so ``0`` lists nothing and an oversized
      value lists every class.

  Yields:
    One formatted ``str`` line per video, terminated by a newline.
  """
  for video_index, video_id in enumerate(video_ids):
    video_predictions = predictions[video_index]
    video_labels = labels[video_index]
    num_classes = len(video_predictions)

    # Clamped to 1 so the PERR division below never divides by zero; with no
    # positive label this still selects one class for the label columns.
    n_recall = max(int(numpy.sum(video_labels)), 1)

    # labels: the n_recall classes with the largest label values.
    label_indices = numpy.argpartition(video_labels, -n_recall)[-n_recall:]
    label_str = _format_scored_classes(video_predictions, label_indices)

    # predictions: clamp k, because argpartition raises for k > num_classes
    # and a "[-0:]" slice would select every class instead of none.
    k = min(max(top_k, 0), num_classes)
    if k > 0:
      top_k_indices = numpy.argpartition(video_predictions, -k)[-k:]
    else:
      top_k_indices = []
    top_k_str = _format_scored_classes(video_predictions, top_k_indices)

    # compute PERR: share of the top n_recall predictions that are labelled.
    top_n_indices = numpy.argpartition(video_predictions, -n_recall)[-n_recall:]
    positives = [video_labels[class_index] for class_index in top_n_indices]
    perr = sum(positives) / float(n_recall)

    # URL: ids that are already text must not be decoded again.
    if not isinstance(video_id, str):
      video_id = video_id.decode('utf-8')
    url = "https://www.youtube.com/watch?v=" + video_id

    yield url + "\t" + str(1-perr) + "\t" + top_k_str + "\t" + label_str + "\n"
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号