stacked_simple.py source code

python

Project: deep-summarization    Author: harpribot
def _load_model(self):
        """
        Creates the encoder-decoder model.

        :return: None
        """
        # Initial memory value for the recurrence.
        self.prev_mem = tf.zeros((self.train_batch_size, self.memory_dim))

        # Choose the recurrent cell type (RNN/GRU/LSTM).
        with tf.variable_scope("train_test", reuse=True):
            cell = self.get_cell()
            # Stack num_layers copies of the cell to form a multi-layer decoder.
            # (The [cell] * n idiom is valid for the legacy TF 0.x API this code
            # targets; newer releases require one cell instance per layer.)
            self.cell = tf.nn.rnn_cell.MultiRNNCell([cell] * self.num_layers)

        # Embedding seq2seq model, built twice in the same variable scope:
        # once for training (decoder fed the ground-truth inputs) and once,
        # with reuse=True and feed_previous=True, for testing (decoder fed
        # back its own previous predictions). Note that the sixth positional
        # argument of these legacy tf.nn.seq2seq functions is the embedding
        # size, which this code sets to self.seq_length.
        if not self.attention:
            with tf.variable_scope("train_test"):
                self.dec_outputs, self.dec_memory = tf.nn.seq2seq.embedding_rnn_seq2seq(
                                self.enc_inp, self.dec_inp, self.cell,
                                self.vocab_size, self.vocab_size, self.seq_length)
            with tf.variable_scope("train_test", reuse=True):
                self.dec_outputs_tst, _ = tf.nn.seq2seq.embedding_rnn_seq2seq(
                                self.enc_inp, self.dec_inp, self.cell,
                                self.vocab_size, self.vocab_size, self.seq_length, feed_previous=True)
        else:
            with tf.variable_scope("train_test"):
                self.dec_outputs, self.dec_memory = tf.nn.seq2seq.embedding_attention_seq2seq(
                                self.enc_inp, self.dec_inp, self.cell,
                                self.vocab_size, self.vocab_size, self.seq_length)
            with tf.variable_scope("train_test", reuse=True):
                self.dec_outputs_tst, _ = tf.nn.seq2seq.embedding_attention_seq2seq(
                                self.enc_inp, self.dec_inp, self.cell,
                                self.vocab_size, self.vocab_size, self.seq_length, feed_previous=True)
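
For reference, below is a minimal, hypothetical sketch of how the per-timestep inputs this method consumes could be wired up under the same legacy (TF ≤ 0.12) tf.nn.seq2seq API. Only enc_inp, dec_inp, seq_length, and vocab_size appear in the snippet above; every other name, the GO-token convention, and the constant loss weights are assumptions for illustration.

python
import tensorflow as tf

seq_length = 50      # assumed maximum sequence length
vocab_size = 20000   # assumed vocabulary size

# The legacy seq2seq functions take *lists* of per-timestep tensors:
# one int32 batch of token ids for each position.
enc_inp = [tf.placeholder(tf.int32, shape=(None,), name="enc%i" % t)
           for t in range(seq_length)]
labels = [tf.placeholder(tf.int32, shape=(None,), name="labels%i" % t)
          for t in range(seq_length)]

# Decoder inputs: a GO token (id 0 here, an assumption) followed by the
# targets shifted right by one step.
dec_inp = ([tf.zeros_like(enc_inp[0], dtype=tf.int32, name="GO")]
           + labels[:-1])

# Uniform per-timestep loss weights; real code would zero out padding.
weights = [tf.ones_like(l, dtype=tf.float32) for l in labels]

# After _load_model() has populated self.dec_outputs, a training loss
# could be formed with the matching legacy helper:
# loss = tf.nn.seq2seq.sequence_loss(self.dec_outputs, labels, weights)

The feed_previous=True graph shares its weights with the training graph via reuse=True, so only the decoding behavior differs between the two.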