def build_model(self):
    """Build the model graph and store the resulting tensors on ``self``.

    Steps, in the order the ops are created (keep this order: the final
    ``self.alpha(..., reuse=True)`` call presumably relies on variables created
    while the Kalman filter graph is built -- confirm before reordering):

    1. Encode ``self.x`` with ``self.encoder`` into ``a_seq`` / ``a_mu`` / ``a_var``.
    2. Build a ``KalmanFilter`` (LGSSM) that observes ``a_seq``, with the
       ``self.alpha`` network started from a zero LSTM state.
    3. Run the smoother (smoothed posterior over z) and the filter (filtered
       posterior, used only for imputation plots).
    4. Map the smoothed z back to ``a`` through ``C``. When
       ``self.config.sample_z`` is true, this projection replaces the encoder
       samples as the decoder input.
    5. Decode ``a_seq`` with ``self.decoder``.
    6. Build the generative-sampling ops used for plotting.

    Side effects:
        Sets ``self.kf``, ``self.n_steps_gen``, ``self.out_gen_det``,
        ``self.out_gen``, ``self.out_gen_det_impute``, ``self.out_alpha`` and
        ``self.model_vars``. ``model_vars`` is a dict of the intermediate
        tensors; its ``'filter'`` entry holds the filtered posterior.

    Returns:
        ``self``, so the call can be chained.
    """
    # Encoder q(a|x)
    a_seq, a_mu, a_var = self.encoder(self.x)
    # Keep the raw encoder samples: a_seq may be overwritten below when sample_z is set.
    a_vae = a_seq

    # Initial state for the alpha RNN. The cell is built only to obtain a
    # correctly sized zero state. Its width presumably has to match the LSTM
    # inside self.alpha (doubled when config.learn_u is set) -- confirm.
    alpha_rnn_units = self.config.alpha_units * 2 if self.config.learn_u else self.config.alpha_units
    dummy_lstm = BasicLSTMCell(alpha_rnn_units)
    state_init_rnn = dummy_lstm.zero_state(self.config.batch_size, tf.float32)

    # Initialize Kalman filter (LGSSM)
    self.kf = KalmanFilter(dim_z=self.config.dim_z,
                           dim_y=self.config.dim_a,
                           dim_u=self.config.dim_u,
                           dim_k=self.config.K,
                           A=self.init_vars['A'],  # state transition function
                           B=self.init_vars['B'],  # control matrix
                           C=self.init_vars['C'],  # measurement function
                           R=self.init_vars['R'],  # measurement noise
                           Q=self.init_vars['Q'],  # process noise
                           y=a_seq,  # observations: the encoder samples a
                           u=None,  # no control inputs are fed in
                           mask=self.mask,
                           mu=self.init_vars['mu'],
                           Sigma=self.init_vars['Sigma'],
                           y_0=self.init_vars['a_0'],
                           alpha=self.alpha,
                           state=state_init_rnn
                           )

    # Smoothed posterior over z, together with the A, B, C returned by the
    # smoother and alpha_plot (kept for plotting).
    smooth, A, B, C, alpha_plot = self.kf.smooth()

    # Filtered posterior, used only for imputation plots. It is named `filtered`
    # so the builtin `filter` is not shadowed; it is still exported under the
    # 'filter' key of model_vars below.
    filtered, _, _, C_filter, _ = self.kf.filter()

    # Project the smoothed z back to a (for plotting): a = C z, using smooth[0]
    # (presumably the smoothed mean -- confirm against KalmanFilter.smooth()).
    a_mu_pred = tf.matmul(C, tf.expand_dims(smooth[0], 2), transpose_b=True)
    a_mu_pred_seq = tf.reshape(a_mu_pred, tf.stack((-1, self.ph_steps, self.config.dim_a)))
    if self.config.sample_z:
        # Decode from the projection of z instead of the encoder samples.
        a_seq = a_mu_pred_seq

    # Decoder p(x|a)
    x_hat, x_mu, x_var = self.decoder(a_seq)

    # Compute variables for generation from the model (for plotting)
    self.n_steps_gen = self.config.n_steps_gen  # number of steps to generate
    self.out_gen_det = self.kf.sample_generative_tf(smooth, self.n_steps_gen, deterministic=True,
                                                    init_fixed_steps=self.config.t_init_mask)
    self.out_gen = self.kf.sample_generative_tf(smooth, self.n_steps_gen, deterministic=False,
                                                init_fixed_steps=self.config.t_init_mask)
    # Deterministic rollout over the full test-sequence length (imputation plots).
    self.out_gen_det_impute = self.kf.sample_generative_tf(smooth, self.test_data.timesteps, deterministic=True,
                                                           init_fixed_steps=self.config.t_init_mask)
    # One step of the alpha network from the zero state, reusing its existing variables.
    self.out_alpha, _, _, _ = self.alpha(self.a_prev, state=state_init_rnn, u=None, init_buffer=True, reuse=True)

    # Collect generated model variables
    self.model_vars = dict(x_hat=x_hat, x_mu=x_mu, x_var=x_var,
                           a_seq=a_seq, a_mu=a_mu, a_var=a_var, a_vae=a_vae,
                           smooth=smooth, A=A, B=B, C=C, alpha_plot=alpha_plot,
                           a_mu_pred_seq=a_mu_pred_seq, filter=filtered, C_filter=C_filter)

    return self
评论列表
文章目录