model_sentences.py source code

python

Project: onto-lstm  Author: pdasigi
def get_attention(self, C_ind):
    if not self.model:
      raise RuntimeError("Model not trained!")
    # Initialise both lookups so the check below cannot hit an undefined name.
    model_embedding = None
    model_lstm = None
    for layer in self.model.layers:
      if layer.name.lower() == "embedding":
        model_embedding = layer
      if layer.name.lower() == "sent_lstm":
        model_lstm = layer
    if model_embedding is None or model_lstm is None:
      raise RuntimeError("Did not find expected layers")
    lstm_weights = model_lstm.get_weights()
    embedding_weights = model_embedding.get_weights()
    embed_in_dim, embed_out_dim = embedding_weights[0].shape
    # Rebuild the embedding and LSTM layers with the trained weights; the new
    # LSTM is created with return_attention=True so it outputs attention values.
    att_embedding = HigherOrderEmbedding(input_dim=embed_in_dim, output_dim=embed_out_dim,
                                         weights=embedding_weights)
    onto_lstm = OntoAttentionLSTM(input_dim=embed_out_dim, output_dim=embed_out_dim,
                                  input_length=model_lstm.input_length,
                                  num_senses=self.num_senses, num_hyps=self.num_hyps,
                                  use_attention=True, return_attention=True,
                                  weights=lstm_weights)
    att_input = Input(shape=C_ind.shape[1:], dtype='int32')
    att_sent_rep = att_embedding(att_input)
    att_output = onto_lstm(att_sent_rep)
    att_model = Model(input=att_input, output=att_output)
    att_model.compile(optimizer='adam', loss='mse') # optimizer and loss are not needed since we are not going to train this model.
    C_att = att_model.predict(C_ind)
    sys.stderr.write("Got attention values. Input, output shapes: %s %s\n" % (C_ind.shape, C_att.shape))
    return C_att
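
The layer lookup at the top of get_attention can be tried on its own, without Keras. The following is a minimal sketch, not code from onto-lstm: FakeLayer and find_layers are hypothetical names introduced for illustration, and the only assumption is that each layer exposes a `name` attribute, as the loop above does.

```python
class FakeLayer(object):
    """Hypothetical stand-in for a Keras layer; it only carries a name."""

    def __init__(self, name):
        self.name = name


def find_layers(layers, wanted):
    """Return {lowercased name: layer} for every name in `wanted`.

    Names are matched case-insensitively, mirroring layer.name.lower()
    in get_attention. Raises RuntimeError if any wanted layer is missing.
    """
    found = {}
    for layer in layers:
        key = layer.name.lower()
        if key in wanted:
            found[key] = layer
    missing = [name for name in wanted if name not in found]
    if missing:
        raise RuntimeError("Did not find expected layers: %s" % ", ".join(missing))
    return found


# Illustrative layer list; a real model would supply model.layers instead.
layers = [FakeLayer("input"), FakeLayer("Embedding"), FakeLayer("sent_lstm")]
found = find_layers(layers, ("embedding", "sent_lstm"))
```

Collecting the matches in a dictionary means a missing layer is reported by name, and no variable is left undefined when the loop finds nothing.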