def print_attention_values(self, input_file, test_inputs, output_file):
    """Dump per-sentence attention values over hypernyms to a JSON file.

    The attention values come from ``self.get_attention(test_inputs)`` and
    are paired, in order, with the sentences read from ``input_file``.  The
    result is one JSON list with an entry per sentence::

        {"input": <sentence>,
         "tokens": [{"surface_form": <word>,
                     "senses": [{"id": <sense index>,
                                 "hypernyms": [{<name>: <value>}, ...]}]}]}

    ``<value>`` is a float, or ``{"forward": float, "backward": float}``
    when the attention for a hypernym is a 2-tuple.  Senses with no
    attention entries are omitted, but ``id`` keeps the original index.

    Args:
        input_file: Path to a tab-separated file; the second column of each
            line is the whitespace-tokenized sentence.
        test_inputs: Passed unchanged to ``self.get_attention``.
        output_file: Path of the JSON file to write (UTF-8).

    Raises:
        IndexError: If a line of ``input_file`` has no second tab-separated
            column (this includes blank lines).
    """
    sent_attention_outputs = self.get_attention(test_inputs)

    # NOTE(review): no encoding is given for the input, so decoding follows
    # the interpreter default -- confirm the files are ASCII/UTF-8.
    with codecs.open(input_file) as infile:
        tagged_sentences = [line.strip().split("\t")[1] for line in infile]

    full_json_struct = []
    # zip() stops at the shorter sequence at every level, so surplus
    # sentences, words or attention rows are silently dropped.
    for sent_attention, tagged_sentence in zip(sent_attention_outputs,
                                               tagged_sentences):
        tokens = []
        for tagged_word, word_attention in zip(tagged_sentence.split(),
                                               sent_attention):
            senses = []
            for sense_num, sense_attention in enumerate(word_attention):
                if len(sense_attention) == 0:
                    continue
                hypernyms = []
                for hyp_name, hyp_att in sense_attention:
                    if isinstance(hyp_att, tuple):
                        # Forward and backward attention are reported
                        # separately (they are not averaged).
                        value = {"forward": float(hyp_att[0]),
                                 "backward": float(hyp_att[1])}
                    else:
                        value = float(hyp_att)
                    hypernyms.append({hyp_name: value})
                senses.append({"id": sense_num, "hypernyms": hypernyms})
            tokens.append({"surface_form": tagged_word, "senses": senses})
        full_json_struct.append({"input": tagged_sentence, "tokens": tokens})

    # Open the output only once the data is ready, so a failure above does
    # not leave a truncated file; ``with`` closes it even if the write fails.
    with codecs.open(output_file, "w", "utf-8") as outfile:
        # Explicit write + newline matches what ``print >>outfile`` emitted
        # and works on both Python 2 and Python 3.
        outfile.write(json.dumps(full_json_struct, indent=2) + "\n")
# [web-page residue, not code: "Comment list"]
# [web-page residue, not code: "Article table of contents"]