model_pp_attachment.py source code

Language: Python

Project: onto-lstm    Author: pdasigi
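
The method below, taken from the project's PP attachment model, runs the lazily built attention model over a batch of inputs and maps each hypernym index back to its synset token, yielding nested lists of (synset, attention) pairs per sense, per word, per sentence.
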
def get_attention(self, inputs):
        '''
        Takes inputs and returns pairs of synsets and corresponding attention values.
        '''
        if not self.attention_model:
            self.define_attention_model()
        attention_outputs = self.attention_model.predict(inputs)
        sent_attention_values = []
        for sentence_input, sentence_attention in zip(inputs, attention_outputs):
            word_attention_values = []
            for word_input, word_attention in zip(sentence_input, sentence_attention):
                # Size of word input is (senses, hyps+1)
                # Ignoring the last hyp index because that is just the word index put there by
                # OntoAwareEmbedding for sense priors.
                if word_input.sum() == 0:
                    # This is just padding
                    continue
                word_input = word_input[:, :-1]  # removing last hyp index.
                sense_hyp_prod = self.num_senses * self.num_hyps
                assert len(word_attention) == sense_hyp_prod or len(word_attention) == 2 * sense_hyp_prod
                attention_per_sense = []
                if len(word_attention) == 2 * sense_hyp_prod:
                    # The encoder is Bidirectional. We have attentions from both directions.
                    forward_sense_attention = word_attention[:len(word_attention) // 2]
                    backward_sense_attention = word_attention[len(word_attention) // 2:]
                    # list() so the zipped (forward, backward) pairs can be sliced below
                    # (zip is lazy in Python 3).
                    processed_attention = list(zip(forward_sense_attention, backward_sense_attention))
                else:
                    # Encoder is not bidirectional
                    processed_attention = word_attention
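                # Split the flat attention vector into one chunk of num_hyps values per sense.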
                hyp_ind = 0
                while hyp_ind < len(processed_attention):
                    attention_per_sense.append(processed_attention[hyp_ind:hyp_ind+self.num_hyps])
                    hyp_ind += self.num_hyps

                sense_attention_values = []
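                # For each sense, pair the surviving hypernym indices with their attention
                # weights; index 0 is padding and is skipped.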
                for sense_input, attention_per_hyp in zip(word_input, attention_per_sense):
                    hyp_attention_values = []
                    for hyp_input, hyp_attention in zip(sense_input, attention_per_hyp):
                        if hyp_input == 0:
                            continue
                        hyp_attention_values.append((self.data_processor.get_token_from_index(hyp_input,
                                                                                              onto_aware=True),
                                                     hyp_attention))
                    sense_attention_values.append(hyp_attention_values)
                word_attention_values.append(sense_attention_values)
            sent_attention_values.append(word_attention_values)
        return sent_attention_values
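
A minimal usage sketch follows; the names model and test_inputs are assumptions, not from the repository. The returned structure nests sentence → word → sense → (synset, attention) pairs, where each attention value is a scalar, or a (forward, backward) pair when the encoder is bidirectional.

# Hypothetical usage; `model` is assumed to be a trained instance of the class
# this method belongs to, and `test_inputs` the padded index array of shape
# (sentences, words, senses, hyps + 1) produced by the project's data processor.
attention = model.get_attention(test_inputs)
for word_values in attention[0]:              # first sentence, non-padding words only
    for sense_values in word_values:          # one list per sense
        for synset, weight in sense_values:   # weight may be a (fwd, bwd) tuple
            print(synset, weight)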