vrnn.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:VRNN 作者: harryross263 项目源码 文件源码
def __call__(self, inputs, state, scope=None):
        """Variational recurrent neural network cell (VRNN)."""
        with tf.variable_scope(scope or type(self).__name__):
            # Update the hidden state.
            z_t, z_mean_t, z_log_sigma_sq_t = state
            h_t_1 = self._activation(_linear(
                    [inputs, z_t, z_mean_t, z_log_sigma_sq_t],
                    2 * self._num_units,
                    True))
            z_mean_t_1, z_log_sigma_sq_t_1 = tf.split(1, 2, h_t_1)

            # Sample.
            eps = tf.random_normal((tf.shape(inputs)[0], self._num_units), 0.0, 1.0,
                    dtype=tf.float32)
            z_t_1 = tf.add(z_mean_t_1, tf.mul(tf.sqrt(tf.exp(z_log_sigma_sq_t_1)),
                    eps))

            return z_t_1, VRNNStateTuple(z_t_1, z_mean_t_1, z_log_sigma_sq_t_1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号