def get_attention(self, inputs):
    '''
    Takes encoded inputs and returns, for each word in each sentence, pairs of
    synsets and the corresponding attention values.
    '''
    if not self.attention_model:
        self.define_attention_model()
    attention_outputs = self.attention_model.predict(inputs)
    sent_attention_values = []
    for sentence_input, sentence_attention in zip(inputs, attention_outputs):
        word_attention_values = []
        for word_input, word_attention in zip(sentence_input, sentence_attention):
            # Size of word_input is (num_senses, num_hyps + 1). We ignore the last
            # hyp index because that is just the word index put there by
            # OntoAwareEmbedding for computing sense priors.
            if word_input.sum() == 0:
                # This is just padding.
                continue
            word_input = word_input[:, :-1]  # Removing the last hyp index.
            sense_hyp_prod = self.num_senses * self.num_hyps
            assert len(word_attention) == sense_hyp_prod or len(word_attention) == 2 * sense_hyp_prod
            attention_per_sense = []
            if len(word_attention) == 2 * sense_hyp_prod:
                # The encoder is bidirectional. We have attentions from both directions,
                # concatenated; pair them up per hypernym.
                forward_sense_attention = word_attention[:len(word_attention) // 2]
                backward_sense_attention = word_attention[len(word_attention) // 2:]
                processed_attention = list(zip(forward_sense_attention, backward_sense_attention))
            else:
                # The encoder is not bidirectional.
                processed_attention = word_attention
            # Group the flat attention values into chunks of num_hyps values, one chunk per sense.
            for hyp_ind in range(0, len(processed_attention), self.num_hyps):
                attention_per_sense.append(processed_attention[hyp_ind:hyp_ind + self.num_hyps])
            sense_attention_values = []
            for sense_input, attention_per_hyp in zip(word_input, attention_per_sense):
                hyp_attention_values = []
                for hyp_input, hyp_attention in zip(sense_input, attention_per_hyp):
                    if hyp_input == 0:
                        # Padding within the hypernym dimension.
                        continue
                    hyp_attention_values.append(
                        (self.data_processor.get_token_from_index(hyp_input, onto_aware=True),
                         hyp_attention))
                sense_attention_values.append(hyp_attention_values)
            word_attention_values.append(sense_attention_values)
        sent_attention_values.append(word_attention_values)
    return sent_attention_values
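
# Usage sketch (illustrative only; `model` and `prepare_input` are assumed names, not
# confirmed by the code above). `inputs` is expected to be the onto-aware index array of
# shape (num_sentences, num_words, num_senses, num_hyps + 1) produced by the same
# data_processor the model was trained with:
#
#     inputs = model.data_processor.prepare_input(sentences, onto_aware=True)
#     attention = model.get_attention(inputs)
#     # attention[i][j][k] is a list of (synset, attention_value) pairs, one per
#     # hypernym of sense k of word j in sentence i. With a bidirectional encoder,
#     # each attention value is a (forward, backward) pair.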