def backward_step_fn(self, params, inputs):
    """
    Perform one backward (smoothing) step over a batch; meant to be the
    step function passed to tf.scan.

    The update applied here is:
        J          = Sigma_filt_t @ A^T @ inv(Sigma_pred_tp1)
        mu_back    = mu_filt_t    + J @ (mu_back_prev    - mu_pred_tp1)
        Sigma_back = Sigma_filt_t + J @ (Sigma_back_prev - Sigma_pred_tp1) @ J^H

    :param params: pair (mu_back, Sigma_back) carried over from the previous
        scan step.
    :param inputs: tuple (mu_pred_tp1, Sigma_pred_tp1, mu_filt_t,
        Sigma_filt_t, A), each with a leading batch dimension. A is
        transposed with perm [0, 2, 1], so it must be rank 3.
        NOTE(review): the means are fed to tf.matmul, so they are presumably
        column vectors of shape (batch, dim_z, 1) -- confirm against caller.
    :return: tuple (mu_back, Sigma_back) for the current step.
    """
    mu_next, Sigma_next = params
    mu_pred, Sigma_pred, mu_filt, Sigma_filt, A = inputs

    # Gain J = Sigma_filt @ A^T @ inv(Sigma_pred), computed per batch element.
    A_transposed = tf.transpose(A, [0, 2, 1])
    Sigma_pred_inv = tf.matrix_inverse(Sigma_pred)
    gain = tf.matmul(Sigma_filt, tf.matmul(A_transposed, Sigma_pred_inv))

    # Correction to the mean: J @ (carried mean - predicted mean).
    mean_correction = tf.matmul(gain, mu_next - mu_pred)

    # Correction to the covariance: J @ (carried cov - predicted cov) @ J^H;
    # adjoint_b=True applies the (conjugate) transpose of the gain on the right.
    cov_correction = tf.matmul(
        gain,
        tf.matmul(Sigma_next - Sigma_pred, gain, adjoint_b=True),
    )

    return mu_filt + mean_correction, Sigma_filt + cov_correction
评论列表
文章目录