model_pp_attachment.py source code

python

Project: onto-lstm    Author: pdasigi
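# Excerpt of a method from the PP-attachment model class in this file. The enclosing
# module is assumed (not shown here) to import sys, Keras's Input, Model, and
# Bidirectional, and the project's own OntoAttentionLSTM layer.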
def define_attention_model(self):
        '''
        Take necessary parts out of the model to get OntoLSTM attention.
        '''
        if not self.model:
            raise RuntimeError("Model not trained yet!")
        input_shape = self.model.get_input_shape_at(0)
        input_layer = Input(input_shape[1:], dtype='int32')  # removing batch size
        embedding_layer = None
        encoder_layer = None
        for layer in self.model.layers:
            if layer.name == "embedding":
                embedding_layer = layer
            elif layer.name == "onto_lstm":
                # We need to redefine the OntoLSTM layer with the learned weights and set return attention to True.
                # Assuming we'll want attention values for all words (return_sequences = True)
                if isinstance(layer, Bidirectional):
                    onto_lstm = OntoAttentionLSTM(input_dim=self.embed_dim, output_dim=self.embed_dim,
                                                  num_senses=self.num_senses, num_hyps=self.num_hyps,
                                                  use_attention=True, return_attention=True, return_sequences=True,
                                                  consume_less='gpu')
                    encoder_layer = Bidirectional(onto_lstm, weights=layer.get_weights())
                else:
                    encoder_layer = OntoAttentionLSTM(input_dim=self.embed_dim,
                                                      output_dim=self.embed_dim, num_senses=self.num_senses,
                                                      num_hyps=self.num_hyps, use_attention=True,
                                                      return_attention=True, return_sequences=True,
                                                      consume_less='gpu', weights=layer.get_weights())
                break
        if not embedding_layer or not encoder_layer:
            raise RuntimeError("Required layers not found!")
        attention_output = encoder_layer(embedding_layer(input_layer))
        self.attention_model = Model(inputs=input_layer, outputs=attention_output)
        print >>sys.stderr, "Attention model summary:"
        self.attention_model.summary()
        self.attention_model.compile(loss="mse", optimizer="sgd")  # Loss and optimizer do not matter!
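For readers unfamiliar with this "probe model" pattern, here is a minimal, self-contained sketch of the same idea (not from the onto-lstm repo; plain Keras layers stand in for OntoAttentionLSTM): rebuild a second Model that reuses a trained model's embedding layer and re-instantiates the encoder with different return flags, initialised from the trained weights.

import numpy as np
from keras.layers import Input, Embedding, LSTM, Dense
from keras.models import Model

# Train a toy sequence classifier.
inputs = Input(shape=(10,), dtype='int32')
embedded = Embedding(input_dim=100, output_dim=16, name="embedding")(inputs)
encoded = LSTM(16, name="encoder")(embedded)  # return_sequences defaults to False
outputs = Dense(1, activation='sigmoid')(encoded)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss="binary_crossentropy", optimizer="adam")
model.fit(np.random.randint(0, 100, (32, 10)), np.random.randint(0, 2, (32, 1)),
          epochs=1, verbose=0)

# Probe model: reuse the trained embedding layer, but re-instantiate the encoder with
# return_sequences=True and copy over the learned weights, mirroring what
# define_attention_model does for the OntoAttentionLSTM encoder.
probe_input = Input(shape=(10,), dtype='int32')
probe_encoder = LSTM(16, return_sequences=True,
                     weights=model.get_layer("encoder").get_weights())
probe_output = probe_encoder(model.get_layer("embedding")(probe_input))
probe_model = Model(inputs=probe_input, outputs=probe_output)
print(probe_model.predict(np.random.randint(0, 100, (2, 10))).shape)  # (2, 10, 16)

The LSTM weights are independent of the return flags, which is why the copied weights load cleanly into the re-instantiated layer; the same holds for the OntoAttentionLSTM redefinition above.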