def _escape_html(text):
    """Return *text* with the HTML-special characters &, < and > escaped.

    Tokens such as ``<unk>`` would otherwise be parsed by the browser as
    (unknown) tags and silently disappear from the rendered page.
    """
    return text.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')


def gen_heatmap(model_name, max_context_len=150, max_examples=20):
    """Write an HTML page of question/context attention heatmaps for a model.

    The page is written to ``heatmap_analysis/<model_name>.html`` (the
    directory is created if missing). Batches whose context dimension
    (``batch["context"].shape[1]``) exceeds ``max_context_len`` are skipped.
    For every other batch, the evaluator processes the batch. Then the
    aggregated ``"question_context_attention"`` of the batch's first example
    is rendered as a plotly heatmap, preceded by that example's answer tokens.
    Generation stops once ``max_examples`` heatmaps have been written.

    Args:
        model_name: Passed to ``build_evaluator`` and used as the output
            file's base name.
        max_context_len: Skip batches whose context is wider than this
            (default 150, the previously hard-coded value).
        max_examples: Maximum number of heatmaps to write (default 20, the
            previously hard-coded value). Must be at least 1.

    Raises:
        ValueError: If ``max_examples`` is less than 1.
    """
    if max_examples < 1:
        raise ValueError("max_examples must be >= 1, got %r" % (max_examples,))

    evaluator, valid_stream, ds = build_evaluator(model_name)

    out_dir = 'heatmap_analysis'
    if not os.path.isdir(out_dir):
        # Create the output directory instead of failing inside open().
        os.makedirs(out_dir)
    analysis_path = os.path.join(out_dir, model_name + ".html")

    printed = 0
    # `with` guarantees the file is closed even if evaluation or plotting raises.
    with open(analysis_path, 'w') as out_file:
        out_file.write('<html>')
        out_file.write('<body style="background-color:white">')
        for batch in valid_stream.get_epoch_iterator(as_dict=True):
            if batch["context"].shape[1] > max_context_len:
                continue
            evaluator.initialize_aggregators()
            evaluator.process_batch(batch)
            analysis_results = evaluator.get_aggregated_values()
            # NOTE(review): q_c_attention[0] is assumed to be
            # (question_len, context_len), so rows line up with the question
            # labels (y) and columns with the context labels (x). Confirm
            # against the model.
            q_c_attention = analysis_results["question_context_attention"]

            # Only the first example of each batch is visualised. The position
            # index is attached to each token, presumably so that repeated
            # words remain distinct axis categories.
            context_words = [ds.vocab[i] + ' ' + str(index)
                             for index, i in enumerate(batch["context"][0])]
            question_words = [str(index) + ' ' + ds.vocab[i]
                              for index, i in enumerate(batch["question"][0])]
            answer_words = [ds.vocab[i] for i in batch["answer"][0]]

            out_file.write('answer: ' + _escape_html(' '.join(answer_words)))
            out_file.write('<br>')

            data = [go.Heatmap(z=q_c_attention[0], x=context_words,
                               y=question_words, colorscale='Viridis')]
            # Embed plotly.js only in the first div. Later divs reuse the
            # library the first one loaded instead of inlining another full
            # copy of it.
            div = plotly.offline.plot(data, auto_open=False, output_type='div',
                                      include_plotlyjs=(printed == 0))
            out_file.write(div)
            out_file.write('<br>')
            out_file.write('<br>')

            printed += 1
            if printed >= max_examples:
                break
        out_file.write('</body>')
        out_file.write('</html>')
    print("done ;)")
评论列表
文章目录