encoders.py source file

python

Project: almond-nnparser  Author: Stanford-Mobisocial-IoT-Lab
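def encode(self, inputs, _input_length, _parses):
    # Bag-of-words encoder: project every token embedding, then sum over time.
    # `inputs` is assumed to have shape (batch, time, embed_size);
    # `_input_length` and `_parses` are unused here.
    with tf.variable_scope('BagOfWordsEncoder'):
        W = tf.get_variable('W', (self.embed_size, self.output_size))
        b = tf.get_variable('b', shape=(self.output_size,), initializer=tf.constant_initializer(0, tf.float32))

        # Contract the embedding axis of `inputs` with the first axis of W,
        # giving one hidden state per token: (batch, time, output_size).
        enc_hidden_states = tf.tanh(tf.tensordot(inputs, W, [[2], [0]]) + b)
        # Summing over the time axis discards word order: (batch, output_size).
        enc_final_state = tf.reduce_sum(enc_hidden_states, axis=1)

        #assert enc_hidden_states.get_shape()[1:] == (self.config.max_length, self.config.hidden_size)
        if self._cell_type == 'lstm':
            # Reuse the summed vector as both parts of a one-layer LSTM state tuple.
            enc_final_state = (tf.contrib.rnn.LSTMStateTuple(enc_final_state, enc_final_state),)

        # self._dropout is passed as keep_prob, i.e. the probability of keeping a unit.
        enc_output = tf.nn.dropout(enc_hidden_states, keep_prob=self._dropout, seed=12345)

        return enc_output, enc_final_state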
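
The projection and sum in encode can be checked outside TensorFlow. The following is a minimal NumPy sketch, not the project's code: the function name bag_of_words_encode, the tensor sizes, and the random test data are assumptions made for illustration, and dropout and the LSTM state wrapping are left out.

```python
import numpy as np


def bag_of_words_encode(inputs, W, b):
    """Hypothetical NumPy counterpart of the projection and sum in encode.

    inputs: (batch, time, embed_size)
    W:      (embed_size, output_size)
    b:      (output_size,)
    """
    # Same contraction as tf.tensordot(inputs, W, [[2], [0]]):
    # the embedding axis of inputs against the first axis of W.
    hidden_states = np.tanh(np.tensordot(inputs, W, axes=([2], [0])) + b)
    # Sum over the time axis, as tf.reduce_sum(..., axis=1) does.
    final_state = hidden_states.sum(axis=1)
    return hidden_states, final_state


# Illustrative sizes only: batch of 2, 5 tokens, 4-dim embeddings, 3-dim output.
rng = np.random.default_rng(0)
inputs = rng.normal(size=(2, 5, 4))
W = rng.normal(size=(4, 3))
b = np.zeros(3)

hidden_states, final_state = bag_of_words_encode(inputs, W, b)

# Reversing the token order leaves the final state unchanged.
_, final_state_reversed = bag_of_words_encode(inputs[:, ::-1, :], W, b)
print(np.allclose(final_state, final_state_reversed))
```

Because the final state is a plain sum over tokens, it does not depend on token order, which is what makes this a bag-of-words encoder. Since the sum runs over every time step, any padded positions would also contribute unless their projected value happens to be zero.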