def get_answers_matrix(split):
    """Build the one-hot answer matrix for a dataset split.

    Loads the pickled QA table for ``split`` and encodes the
    ``multiple_choice_answer`` column row by row using the module-level
    ``answer_to_onehot_dict`` lookup.

    Args:
        split: ``'train'`` (reads ``data/train_qa``) or ``'val'``
            (reads ``data/val_qa``).

    Returns:
        A float ``numpy`` array of shape ``(number_of_rows, 1001)``. Row
        ``i`` holds the vector that ``answer_to_onehot_dict`` maps the
        lower-cased answer of row ``i`` to; answers missing from the
        dictionary get a vector whose only non-zero entry is index 1000.

    Raises:
        SystemExit: with exit status 1 if ``split`` is neither ``'train'``
            nor ``'val'``.
    """
    if split == 'train':
        data_path = 'data/train_qa'
    elif split == 'val':
        data_path = 'data/val_qa'
    else:
        print('Invalid split!')
        # A bare sys.exit() reports success (status 0) to the calling
        # shell; an invalid split is an error, so exit non-zero.
        sys.exit(1)

    num_classes = 1001
    unknown_index = num_classes - 1  # last slot: answer not in the dictionary

    df = pd.read_pickle(data_path)
    answers = df['multiple_choice_answer'].tolist()

    # Fallback vector used for every answer absent from the lookup.
    default_onehot = np.zeros(num_classes)
    default_onehot[unknown_index] = 1.0

    answer_matrix = np.zeros((len(answers), num_classes))
    for row, answer in enumerate(answers):
        # NOTE(review): assumes answer_to_onehot_dict is keyed by lower-case
        # answer strings and holds length-1001 vectors -- confirm against
        # where it is built.
        answer_matrix[row] = answer_to_onehot_dict.get(
            answer.lower(), default_onehot
        )
    return answer_matrix
prepare_data.py 文件源码
python
阅读 53
收藏 0
点赞 0
评论 0
评论列表
文章目录