def get_label_using_logits_batch(question_id_sublist, logits_batch, vocabulary_index2word_label, f, top_number=5):
    """Write the highest-scoring labels for each question in a batch to ``f``.

    For every row of ``logits_batch``, take the indices of the ``top_number``
    largest scores in descending score order. Map them to labels through
    ``vocabulary_index2word_label``. Pass the resulting list, with the
    question id at the same position, to ``write_question_id_with_labels``.
    Flush ``f`` once after the whole batch has been processed.

    Args:
        question_id_sublist: question ids aligned by position with
            ``logits_batch``; ``question_id_sublist[i]`` is paired with row i.
            An ``IndexError`` is raised if it is shorter than ``logits_batch``.
        logits_batch: iterable of per-question score rows. Each row is
            presumably a 1-D array with one score per label -- TODO confirm
            (an earlier inline note recorded a shape of (1, 128, 2002)).
        vocabulary_index2word_label: lookup from label index to label value.
        f: writable file-like object; it is forwarded to
            ``write_question_id_with_labels`` and flushed at the end.
        top_number: number of labels to keep per question. 0 yields an empty
            label list. A value larger than the row length keeps every label.
    """
    # np.shape uses an ndarray's .shape directly rather than copying the whole
    # batch with np.array() just to report its dimensions.
    print("get_label_using_logits.shape:", np.shape(logits_batch))
    for i, logits in enumerate(logits_batch):
        # Sort ascending, reverse to descending, then keep the first
        # top_number indices. The previous form argsort(...)[-top_number:]
        # returned *every* index for top_number == 0, because [-0:] == [0:].
        top_indices = np.argsort(logits)[::-1][:top_number]
        label_list = [vocabulary_index2word_label[index] for index in top_indices]
        write_question_id_with_labels(question_id_sublist[i], label_list, f)
    f.flush()
# get label using logits
a2_predict_classification.py source file
python
Views: 31
Bookmarks: 0
Likes: 0
Comments: 0
Comment list
Table of contents