baseline.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:shalo 作者: henryre 项目源码 文件源码
def _embed_sentences(self):
        """Embed sentences via the last output cell of an LSTM.

        Looks up word embeddings for ``self.input``. It then runs a
        ``BasicLSTMCell`` with ``self.d`` units over them through
        ``tf.nn.dynamic_rnn``. The run is batch-major (``time_major=False``)
        and uses ``self.input_lengths`` as the per-example sequence lengths.

        Returns:
            A 2-tuple ``(output, {})``.
            - ``output`` is ``get_rnn_output(rnn_out, self.d,
              self.input_lengths)``. NOTE(review): per the summary line,
              that helper presumably picks each sequence's last valid
              output; confirm against its definition.
            - The second element is always an empty dict. Its meaning is
              defined by the caller; confirm.
        """
        word_embeddings = self._get_embedding()
        # Assumes self.input holds integer token ids shaped (batch, max_len),
        # consistent with time_major=False below. TODO confirm.
        word_feats = tf.nn.embedding_lookup(word_embeddings, self.input)
        # Dynamic (run-time) batch size, so the zero state matches the fed batch.
        batch_size = tf.shape(self.input)[0]
        # Only the "LSTM" name prefix is needed; the scope object was unused.
        with tf.variable_scope("LSTM"):
            # Sets the graph-level seed. NOTE(review): the "- 1" offset
            # presumably keeps it distinct from seeds used elsewhere; confirm.
            tf.set_random_seed(self.seed - 1)
            # LSTM architecture: a single cell with self.d hidden units.
            cell = tf.contrib.rnn.BasicLSTMCell(self.d)
            # Start every sequence from an all-zeros state.
            initial_state = cell.zero_state(batch_size, tf.float32)
            # Steps past sequence_length are not computed by dynamic_rnn.
            rnn_out, _ = tf.nn.dynamic_rnn(
                cell,
                word_feats,
                sequence_length=self.input_lengths,
                initial_state=initial_state,
                time_major=False,
            )
        # Get potentials
        return get_rnn_output(rnn_out, self.d, self.input_lengths), {}
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号