analyse.py (source file)

Language: Python
Views: 25 · Favorites: 0 · Likes: 0 · Comments: 0

Project: Question-Answering · Author: arianhosseini · Links: project source, file source
def gen_heatmap(model_name, max_context_len=150, max_examples=20,
                output_dir='heatmap_analysis'):
    """Write question/context attention heatmaps for a model to an HTML page.

    Builds an evaluator for ``model_name`` via ``build_evaluator`` and walks
    one epoch of its validation stream. For every batch whose context is
    short enough, it runs the evaluator and reads the aggregated
    ``"question_context_attention"`` value. It then renders the attention of
    the batch's first example as a Plotly heatmap, with context tokens on the
    x axis and question tokens on the y axis, preceded by that example's
    answer tokens. The page is written to ``<output_dir>/<model_name>.html``.

    Args:
        model_name: Passed to ``build_evaluator``; also the base name of the
            output HTML file.
        max_context_len: Batches with ``batch["context"].shape[1]`` greater
            than this are skipped. Defaults to 150, the original hard-coded
            limit.
        max_examples: Stop once this many heatmaps have been written. Should
            be positive; the check runs after each written example, as in
            the original. Defaults to 20.
        output_dir: Directory for the HTML file; created if it does not
            exist. Defaults to ``'heatmap_analysis'``.
    """
    # Function-scope import keeps this block self-contained;
    # xml.sax.saxutils.escape is available on both Python 2 and Python 3.
    from xml.sax.saxutils import escape

    evaluator, valid_stream, ds = build_evaluator(model_name)

    # Previously a missing directory made open() fail with an IOError.
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    analysis_path = os.path.join(output_dir, model_name + ".html")

    # `with` guarantees the file is closed even if evaluation or plotting
    # raises part-way through the loop.
    with open(analysis_path, 'w') as out_file:
        out_file.write('<html>')
        out_file.write('<body style="background-color:white">')

        printed = 0
        for batch in valid_stream.get_epoch_iterator(as_dict=True):
            # Skip long contexts. Presumably axis 1 is the context length
            # (batch-major layout) -- TODO confirm.
            if batch["context"].shape[1] > max_context_len:
                continue

            evaluator.initialize_aggregators()
            evaluator.process_batch(batch)
            analysis_results = evaluator.get_aggregated_values()
            # Only the first example of each batch is visualised.
            q_c_attention = analysis_results["question_context_attention"][0]

            # Token positions are appended to the labels, presumably so that
            # repeated words remain distinct categories on the heatmap axes.
            context_words = [ds.vocab[i] + ' ' + str(index)
                             for index, i in enumerate(batch["context"][0])]
            question_words = [str(index) + ' ' + ds.vocab[i]
                              for index, i in enumerate(batch["question"][0])]
            answer_words = [ds.vocab[i] for i in batch["answer"][0]]

            # Escape the answer before embedding it in raw HTML: a token
            # containing '<', '>' or '&' would otherwise be parsed as markup
            # and vanish from (or corrupt) the rendered page.
            out_file.write('answer: ' + escape(' '.join(answer_words)))
            out_file.write('<br>')

            data = [
                go.Heatmap(z=q_c_attention, x=context_words,
                           y=question_words, colorscale='Viridis')
            ]
            div = plotly.offline.plot(data, auto_open=False,
                                      output_type='div')
            out_file.write(div)
            out_file.write('<br>')
            out_file.write('<br>')

            printed += 1
            if printed >= max_examples:
                break

        out_file.write('</body>')
        out_file.write('</html>')

    # Parenthesised form prints the same text on Python 2 and Python 3.
    print("done ;)")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号