def check_model(path=MODEL_PATH, file=SAMPLE_CSV_FILE, nsamples=2):
    """Print the model's predictions for the first samples of a dataset.

    Loads the model from ``path``, reads and encodes the data in ``file``,
    runs the model on the first ``nsamples`` samples and prints, in order:
    the raw prediction, the predicted character ids, the true character
    ids, and then for every sample the decoded table, question, correct
    answer and predicted answer.

    Args:
        path: location handed to ``load_model``.
        file: data file handed to ``get_data``.
        nsamples: maximum number of samples to predict and display. When
            fewer samples are available only those are shown.
    """
    model = load_model(path)
    data, dic = get_data(file)
    rows, questions, true_answers = encode_data(data, dic)

    prediction = model.predict([rows[:nsamples], questions[:nsamples]])
    print(prediction)

    # Keep the highest-scoring id at every character position; plain ints
    # keep the printed list readable regardless of the numpy version.
    predicted_answers = [
        [int(np.argmax(character)) for character in sample]
        for sample in prediction
    ]
    print(predicted_answers)
    print(true_answers[:nsamples])

    # ``dic`` is inverted so that ids can be turned back into text.
    inv_dic = {v: k for k, v in dic.items()}

    def decode(char_ids):
        # Id 0 is skipped (presumably padding -- confirm in encode_data).
        return ''.join([inv_dic[char_id] for char_id in char_ids if char_id != 0])

    # zip stops at the shortest input, so a dataset smaller than
    # ``nsamples`` no longer triggers an IndexError.
    samples = zip(
        rows[:nsamples],
        questions[:nsamples],
        true_answers[:nsamples],
        predicted_answers,
    )
    for row, question, true_answer, predicted_answer in samples:
        print('\n')
        print('Table: ' + decode(row))
        print('Question: ' + decode(question))
        print('Answer(correct): ' + decode(true_answer))
        print('Answer(predicted): ' + decode(predicted_answer))
# 评论列表 (comment list)
# 文章目录 (article table of contents)