coref_model.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:e2e-coref 作者: kentonl 项目源码 文件源码
def encode_sentences(self, text_emb, text_len, text_len_mask):
    """Run a forward and a backward LSTM over every sentence and flatten the result.

    Args:
      text_emb: embedded text, indexed as [num_sentences, max_sentence_length, emb].
      text_len: per-sentence lengths; used both as the RNN sequence lengths and
        as the lengths for reversing the backward direction.
      text_len_mask: forwarded unchanged to `self.flatten_emb_by_sentence`.

    Returns:
      Whatever `self.flatten_emb_by_sentence` returns for the concatenated
      forward/backward LSTM outputs and `text_len_mask`.
    """
    num_sentences = tf.shape(text_emb)[0]
    # Not used below; retained so this method creates the same graph ops as before.
    max_sentence_length = tf.shape(text_emb)[1]

    def reverse_within_length(time_major_tensor):
      # Reverses along the time axis (0), per batch entry (axis 1), over text_len steps.
      return tf.reverse_sequence(time_major_tensor,
                                 seq_lengths=text_len,
                                 seq_dim=0,
                                 batch_dim=1)

    def tiled_initial_state(cell):
      # Repeat the cell's initial c/h state num_sentences times along axis 0.
      tiled_c = tf.tile(cell.initial_state.c, [num_sentences, 1])
      tiled_h = tf.tile(cell.initial_state.h, [num_sentences, 1])
      return tf.contrib.rnn.LSTMStateTuple(tiled_c, tiled_h)

    def run_lstm(cell, time_major_inputs, start_state):
      # Only the per-step outputs are needed; the final state is discarded.
      outputs, _ = tf.nn.dynamic_rnn(cell=cell,
                                     inputs=time_major_inputs,
                                     sequence_length=text_len,
                                     initial_state=start_state,
                                     time_major=True)
      return outputs

    # Work in time-major layout: [max_sentence_length, num_sentences, emb].
    time_major_emb = tf.transpose(text_emb, [1, 0, 2])

    with tf.variable_scope("fw_cell"):
      forward_cell = util.CustomLSTMCell(self.config["lstm_size"], num_sentences, self.dropout)
      forward_inputs = forward_cell.preprocess_input(time_major_emb)
    with tf.variable_scope("bw_cell"):
      backward_cell = util.CustomLSTMCell(self.config["lstm_size"], num_sentences, self.dropout)
      # The backward direction consumes each sentence in reversed order.
      backward_inputs = reverse_within_length(backward_cell.preprocess_input(time_major_emb))

    forward_state = tiled_initial_state(forward_cell)
    backward_state = tiled_initial_state(backward_cell)

    with tf.variable_scope("lstm"):
      with tf.variable_scope("fw_lstm"):
        forward_outputs = run_lstm(forward_cell, forward_inputs, forward_state)
      with tf.variable_scope("bw_lstm"):
        backward_outputs = run_lstm(backward_cell, backward_inputs, backward_state)

    # Undo the reversal so backward outputs line up with the forward ones.
    backward_outputs = reverse_within_length(backward_outputs)

    # Join both directions on the last axis, then go back to batch-major:
    # [num_sentences, max_sentence_length, forward + backward outputs].
    both_directions = tf.concat([forward_outputs, backward_outputs], 2)
    batch_major_outputs = tf.transpose(both_directions, [1, 0, 2])
    return self.flatten_emb_by_sentence(batch_major_outputs, text_len_mask)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号