seq2seq.py source code

python

Project: nlvr_tau_nlp_final_proj  Author: udiNaveh
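# TensorFlow 1.x APIs (tf.placeholder, tf.get_variable, tf.contrib) are assumed throughout.
import tensorflow as tf
from tensorflow.contrib.rnn import BasicLSTMCell

# Module-level names defined or imported elsewhere in the project (not shown in this excerpt):
# embeddings_matrix, words_vocabulary, LSTM_HIDDEN_SIZE (WORD_EMB_SIZE appears only in a comment).


def build_sentence_encoder(vocabulary_size):
    """
    Build the computational graph for the bidirectional LSTM sentence encoder. Return only the placeholders
    and tensors that are used by other methods.
    """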
    sentence_oh_placeholder = tf.placeholder(shape=[None, vocabulary_size], dtype=tf.float32,
                                             name="sentence_placeholder")
    word_embeddings_matrix = tf.get_variable("W_we",  # shape=[vocabulary_size, WORD_EMB_SIZE]
                                             initializer=tf.constant(embeddings_matrix, dtype=tf.float32))
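    # one-hot rows x embedding matrix = embedding lookup; expand_dims adds a batch dimension of size 1
    sentence_embedded = tf.expand_dims(tf.matmul(sentence_oh_placeholder, word_embeddings_matrix), 0)
    # placeholder for the sentence's length, passed to the RNN as sequence_length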
    sent_lengths = tf.placeholder(dtype=tf.int32, name="sent_length_placeholder")

    # Forward cell
    lstm_fw_cell = BasicLSTMCell(LSTM_HIDDEN_SIZE, forget_bias=1.0)
    # Backward cell
    lstm_bw_cell = BasicLSTMCell(LSTM_HIDDEN_SIZE, forget_bias=1.0)
    # run both cells over the sentence as a bidirectional RNN (the final states are discarded)
    outputs, _ = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, sentence_embedded, sent_lengths,
                                                 dtype=tf.float32)
    #    outputs: A tuple (output_fw, output_bw) containing the forward and the backward RNN output `Tensor`.
    #    Both output_fw and output_bw are `Tensor`s shaped `[batch_size, max_time, cell_fw.output_size]`.

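    # outputs is an (output_forward, output_backward) tuple. Concatenate them along the feature axis to get
    # the h vectors; [0] drops the batch dimension of size 1.
    lstm_outputs = tf.concat(outputs, 2)[0]  # shape: [max_time, 2 * hidden_layer_size]
    # Forward output at the last step and backward output at step 0: each direction after reading the
    # whole sentence (the -1 index assumes sent_lengths equals the number of words fed).
    final_fw = outputs[0][:, -1, :]
    final_bw = outputs[1][:, 0, :]
    e_m = tf.concat((final_fw, final_bw), axis=1)  # sentence embedding, shape: [1, 2 * LSTM_HIDDEN_SIZE]
    # bag-of-words features for the sentence, appended to the sentence embedding
    sentence_words_bow = tf.placeholder(tf.float32, [None, len(words_vocabulary)], name="sentence_words_bow")
    e_m_with_bow = tf.concat([e_m, sentence_words_bow], axis=1)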

    return sentence_oh_placeholder, sent_lengths, sentence_words_bow, lstm_outputs, e_m_with_bow
    # TODO return sentence_oh_placeholder, sent_lengths, sentence_words_bow, lstm_outputs, e_m
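
The line `tf.matmul(sentence_oh_placeholder, word_embeddings_matrix)` performs the embedding lookup by multiplying a one-hot matrix with the embedding table. The minimal pure-Python sketch below uses toy sizes and values chosen only for illustration (not taken from the project) to show that each row of that product is just the embedding row of the corresponding word, i.e. what `tf.nn.embedding_lookup` would return from integer ids.

```python
def one_hot(ids, vocabulary_size):
    """Return a len(ids) x vocabulary_size one-hot matrix as nested lists."""
    return [[1.0 if j == i else 0.0 for j in range(vocabulary_size)] for i in ids]


def matmul(a, b):
    """Plain matrix product of two nested-list matrices."""
    return [[sum(a_ik * b[k][j] for k, a_ik in enumerate(row)) for j in range(len(b[0]))]
            for row in a]


# Toy embedding table: vocabulary_size = 4 words, embedding size = 3.
embeddings = [
    [0.1, 0.2, 0.3],
    [1.0, 1.1, 1.2],
    [2.0, 2.1, 2.2],
    [3.0, 3.1, 3.2],
]

word_ids = [2, 0, 3]  # a 3-word sentence, as vocabulary indices
sentence_oh = one_hot(word_ids, vocabulary_size=4)   # shape [3, 4]
sentence_embedded = matmul(sentence_oh, embeddings)  # shape [3, 3]

# Each row of the product equals the looked-up embedding row.
print(sentence_embedded == [embeddings[i] for i in word_ids])  # True
```

Feeding integer ids to `tf.nn.embedding_lookup` avoids materializing the one-hot matrix, which matters for large vocabularies; the one-hot form used here is equivalent in result.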
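
`final_bw` reads time step 0 rather than -1 because `tf.nn.bidirectional_dynamic_rnn` reverses the backward RNN's outputs back into input order, so position 0 holds the backward cell's output after it has read the whole sentence. The toy sketch below replaces the LSTM with a hypothetical running-sum "cell" (an assumption made only to keep the arithmetic visible) and shows that both selected positions summarize the full sentence.

```python
def run_rnn(inputs, cell):
    """Unroll `cell` over `inputs` in order, returning the output at every step."""
    state, outputs = 0, []
    for x in inputs:
        state = cell(state, x)
        outputs.append(state)
    return outputs


def bidirectional(inputs, cell):
    """Mimic the time alignment of tf.nn.bidirectional_dynamic_rnn outputs."""
    output_fw = run_rnn(inputs, cell)
    # The backward pass reads the sequence reversed; its outputs are then
    # reversed again so that index t lines up with input position t.
    output_bw = run_rnn(inputs[::-1], cell)[::-1]
    return output_fw, output_bw


def running_sum(state, x):
    # Hypothetical stand-in for an LSTM cell: the state is the sum of inputs seen so far.
    return state + x


words = [1, 2, 3, 4]
output_fw, output_bw = bidirectional(words, running_sum)
print(output_fw)  # [1, 3, 6, 10]
print(output_bw)  # [10, 9, 7, 4]

final_fw = output_fw[-1]  # forward cell after reading every word
final_bw = output_bw[0]   # backward cell after reading every word
print(final_fw, final_bw)  # 10 10
```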
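
The function returns three placeholders that a caller must feed. Reading the shapes off the code: `sentence_oh_placeholder` takes a `[num_words, vocabulary_size]` one-hot matrix, `sent_lengths` a 1-element vector (the sentence is expanded to a batch of one), and `sentence_words_bow` a `[1, len(words_vocabulary)]` row. The sketch below builds those arrays for one tokenized sentence. The helper name `make_encoder_inputs`, the toy vocabulary, the use of one vocabulary for both the embeddings and `words_vocabulary`, and the choice of 0/1 presence indicators (rather than counts) for the bag-of-words are all assumptions, not taken from the project.

```python
def make_encoder_inputs(tokens, vocabulary):
    """Build the arrays fed to the encoder's placeholders for one sentence.

    Hypothetical helper: `vocabulary` maps each word to an integer index.
    """
    size = len(vocabulary)
    ids = [vocabulary[t] for t in tokens]
    # [num_words, vocabulary_size] one-hot rows for sentence_oh_placeholder
    sentence_oh = [[1.0 if j == i else 0.0 for j in range(size)] for i in ids]
    # a batch of one sentence, so sequence_length is a 1-element vector
    lengths = [len(tokens)]
    # [1, vocabulary_size] bag-of-words row of 0/1 presence indicators (an assumption)
    present = set(ids)
    bow = [[1.0 if j in present else 0.0 for j in range(size)]]
    return sentence_oh, lengths, bow


vocabulary = {"there": 0, "is": 1, "a": 2, "yellow": 3, "box": 4}
sentence_oh, lengths, bow = make_encoder_inputs(["there", "is", "a", "box"], vocabulary)
print(len(sentence_oh), len(sentence_oh[0]))  # 4 5
print(lengths)  # [4]
print(bow)  # [[1.0, 1.0, 1.0, 0.0, 1.0]]
```

These arrays would then go into a `feed_dict` keyed by the placeholders that `build_sentence_encoder` returns.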