prepare_data.py 文件源码

python
阅读 53 收藏 0 点赞 0 评论 0

项目:NeuralNetwork-ImageQA 作者: ayushoriginal 项目源码 文件源码
def get_answers_matrix(split):
    if split == 'train':
        data_path = 'data/train_qa'
    elif split == 'val':
        data_path = 'data/val_qa'
    else:
        print('Invalid split!')
        sys.exit()

    df = pd.read_pickle(data_path)
    answers = df[['multiple_choice_answer']].values.tolist()
    answer_matrix = np.zeros((len(answers),1001))
    default_onehot = np.zeros(1001)
    default_onehot[1000] = 1.0

    for i, answer in enumerate(answers):
        answer_matrix[i] = answer_to_onehot_dict.get(answer[0].lower(),default_onehot)

    return answer_matrix
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号