Source snippet: KalmanVariationalAutoencoder.py — from the "kvae" project by simonkamronn (Kalman Variational Auto-Encoder).
def build_model(self):
        """Construct the TF graph for the Kalman VAE (KVAE).

        Wires together the three model components:
          1. VAE encoder q(a|x) producing pseudo-observations ``a``.
          2. A linear Gaussian state-space model (Kalman filter/smoother)
             over latent states ``z`` with mixing weights from the alpha RNN.
          3. VAE decoder p(x|a).

        Side effects: sets ``self.kf``, ``self.n_steps_gen``,
        ``self.out_gen_det``, ``self.out_gen``, ``self.out_gen_det_impute``,
        ``self.out_alpha`` and ``self.model_vars``.

        Returns:
            self, to allow fluent chaining by the caller.
        """
        # Encoder q(a|x): sample sequence plus its Gaussian parameters.
        a_seq, a_mu, a_var = self.encoder(self.x)
        # Keep the raw VAE sample; a_seq may be overwritten below when
        # config.sample_z is set, but plots still need the encoder output.
        a_vae = a_seq

        # Zero initial state for the alpha RNN. The cell is built only to
        # obtain a correctly-shaped zero_state; it is never run here.
        # When the control input u is learned, the LSTM is twice as wide.
        dummy_lstm = BasicLSTMCell(self.config.alpha_units * 2 if self.config.learn_u else self.config.alpha_units)
        state_init_rnn = dummy_lstm.zero_state(self.config.batch_size, tf.float32)

        # Initialize the Kalman filter (LGSSM) over the encoded sequence.
        self.kf = KalmanFilter(dim_z=self.config.dim_z,
                               dim_y=self.config.dim_a,
                               dim_u=self.config.dim_u,
                               dim_k=self.config.K,
                               A=self.init_vars['A'],  # state transition function
                               B=self.init_vars['B'],  # control matrix
                               C=self.init_vars['C'],  # measurement function
                               R=self.init_vars['R'],  # measurement noise
                               Q=self.init_vars['Q'],  # process noise
                               y=a_seq,  # output (pseudo-observations)
                               u=None,  # no external control input here
                               mask=self.mask,
                               mu=self.init_vars['mu'],
                               Sigma=self.init_vars['Sigma'],
                               y_0=self.init_vars['a_0'],
                               alpha=self.alpha,
                               state=state_init_rnn
                               )

        # Smoothed posterior over z, plus the (mixed) LGSSM matrices and
        # the alpha weights for plotting.
        smooth, A, B, C, alpha_plot = self.kf.smooth()

        # Filtered posterior; used only for imputation plots.
        # Renamed from `filter` to avoid shadowing the builtin; the
        # model_vars key below stays `filter` for backward compatibility.
        filtered, _, _, C_filter, _ = self.kf.filter()

        # Reconstruct a from the smoothed prior over z (for plotting):
        # a = C z. smooth[0] presumably holds the smoothed means — TODO confirm
        # against KalmanFilter.smooth().
        a_mu_pred = tf.matmul(C, tf.expand_dims(smooth[0], 2), transpose_b=True)
        a_mu_pred_seq = tf.reshape(a_mu_pred, tf.stack((-1, self.ph_steps, self.config.dim_a)))
        if self.config.sample_z:
            # Decode from the LGSSM prediction instead of the encoder sample.
            a_seq = a_mu_pred_seq

        # Decoder p(x|a).
        x_hat, x_mu, x_var = self.decoder(a_seq)

        # Generative rollouts from the model (for plotting).
        self.n_steps_gen = self.config.n_steps_gen  # number of sampled steps
        self.out_gen_det = self.kf.sample_generative_tf(smooth, self.n_steps_gen, deterministic=True,
                                                        init_fixed_steps=self.config.t_init_mask)
        self.out_gen = self.kf.sample_generative_tf(smooth, self.n_steps_gen, deterministic=False,
                                                    init_fixed_steps=self.config.t_init_mask)
        self.out_gen_det_impute = self.kf.sample_generative_tf(smooth, self.test_data.timesteps, deterministic=True,
                                                               init_fixed_steps=self.config.t_init_mask)
        # Alpha weights for the previous a, with the RNN buffer reinitialized.
        self.out_alpha, _, _, _ = self.alpha(self.a_prev, state=state_init_rnn, u=None, init_buffer=True, reuse=True)

        # Collect generated model variables for downstream loss/plot code.
        self.model_vars = dict(x_hat=x_hat, x_mu=x_mu, x_var=x_var,
                               a_seq=a_seq, a_mu=a_mu, a_var=a_var, a_vae=a_vae,
                               smooth=smooth, A=A, B=B, C=C, alpha_plot=alpha_plot,
                               a_mu_pred_seq=a_mu_pred_seq, filter=filtered, C_filter=C_filter)

        return self
(Webpage footer — comment list and social-media follow prompts — omitted.)