basic_rnn_cells.py source code

python

Project: skiprnn-2017-telecombcn  Author: imatge-upc
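import tensorflow as tf  # TensorFlow 1.x: relies on tf.get_variable and tf.contrib


# Method of the cell class (note the self parameter); it reads self._num_units.
def trainable_initial_state(self, batch_size):
    """
    Create a trainable initial state for the BasicLSTMCell
    :param batch_size: number of samples per batch
    :return: LSTMStateTuple
    """
    def _create_initial_state(batch_size, state_size, trainable=True, initializer=tf.random_normal_initializer()):
        with tf.device('/cpu:0'):
            # A single learned row of shape [1, state_size], shared by every sample.
            s = tf.get_variable('initial_state', shape=[1, state_size], dtype=tf.float32, trainable=trainable,
                                initializer=initializer)
            # Repeat that row batch_size times -> [batch_size, state_size].
            state = tf.tile(s, tf.stack([batch_size, 1]))
        return state

    # Separate variable scopes give the cell state (c) and hidden state (h) their own variables.
    with tf.variable_scope('initial_c'):
        initial_c = _create_initial_state(batch_size, self._num_units)
    with tf.variable_scope('initial_h'):
        initial_h = _create_initial_state(batch_size, self._num_units)
    return tf.contrib.rnn.LSTMStateTuple(initial_c, initial_h)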
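
The tiling step in the method above may be easier to follow outside a TensorFlow graph. The following is a minimal NumPy sketch, not part of the project: the helper name `tile_initial_state`, the seed, and the sizes are assumptions chosen for illustration. It shows how one `[1, state_size]` row is repeated into a `[batch_size, state_size]` batch, so every sample starts from the same state.

```python
import numpy as np


def tile_initial_state(initial_state, batch_size):
    """Repeat a [1, state_size] row batch_size times -> [batch_size, state_size].

    Hypothetical helper for illustration; it mirrors tf.tile(s, [batch_size, 1]).
    """
    return np.tile(initial_state, (batch_size, 1))


# Illustrative sizes (assumptions, not taken from the project).
state_size = 4
rng = np.random.default_rng(0)

# Stand-in for the learned 'initial_state' variable, drawn from a normal distribution.
initial_c = rng.standard_normal((1, state_size)).astype(np.float32)

batch_c = tile_initial_state(initial_c, batch_size=3)
print(batch_c.shape)  # (3, 4)
```

Because every row is a copy of the same vector, the TensorFlow version only has `state_size` trainable values per state, regardless of batch size, which is why the method can accept any `batch_size` at graph-construction time.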