train.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:TF_MemN2N-tableQA 作者: vendi12 项目源码 文件源码
def _decode_ids(char_ids, inv_dic):
    '''
    Turn a sequence of character ids back into a string.

    Ids equal to 0 are skipped. The original code filters them out everywhere;
    presumably 0 is the padding id produced by encode_data (TODO confirm).
    Every other id is looked up in inv_dic and raises KeyError if missing.
    '''
    return ''.join([inv_dic[char_id] for char_id in char_ids if char_id != 0])


def check_model(path=MODEL_PATH, file=SAMPLE_CSV_FILE, nsamples=2):
    '''
    Print model predictions next to the ground truth for the first samples
    of a dataset, to inspect what the model generates.

    Args:
        path: location of the saved model, passed to load_model.
        file: CSV file with the samples, passed to get_data. (The name shadows
            the Python 2 builtin `file`; kept for caller compatibility.)
        nsamples: number of leading samples to predict and display. Clamped
            to the number of encoded samples, so a value larger than the
            dataset no longer raises IndexError in the display loop.
    '''
    # load model
    model = load_model(path)

    # load data and encode it into character-id sequences
    data, dic = get_data(file)
    rows, questions, true_answers = encode_data(data, dic)

    # Slicing below silently truncates to the dataset size, so the display
    # loop must iterate over the same effective count, not the requested one.
    nsamples = min(nsamples, len(rows))

    # predict answers
    prediction = model.predict([rows[:nsamples], questions[:nsamples]])
    print(prediction)
    # np.argmax over each output position's score vector picks the id with
    # the highest score at that position.
    predicted_answers = [[np.argmax(character) for character in sample] for sample in prediction]
    print(predicted_answers)
    print(true_answers[:nsamples])

    # invert the vocabulary so ids can be decoded back into characters
    inv_dic = {char_id: char for char, char_id in dic.items()}
    for i in range(nsamples):
        print('\n')
        print('Table: ' + _decode_ids(rows[i], inv_dic))
        print('Question: ' + _decode_ids(questions[i], inv_dic))
        print('Answer(correct): ' + _decode_ids(true_answers[i], inv_dic))
        # NOTE(review): raises KeyError if the model emits an id that is not in
        # dic (e.g. if its output size exceeds the vocabulary) — confirm sizes.
        print('Answer(predicted): ' + _decode_ids(predicted_answers[i], inv_dic))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号