def __init__(self, dim_z, dim_y, dim_u=0, dim_k=1, **kwargs):
    """Create the TensorFlow variables and inputs of a linear Gaussian state-space model.

    Args:
        dim_z: Size of the latent state. Sets the shapes of ``mu`` (dim_z,),
            ``Sigma``/``A``/``Q`` (dim_z, dim_z), the rows of ``B`` and the
            columns of ``C``.
        dim_y: Size of one observation. Sets the shapes of ``y_0`` (dim_y,),
            ``R`` (dim_y, dim_y), the rows of ``C`` and the last axis of the
            ``y`` placeholder.
        dim_u: Size of the control input. Sets the columns of ``B`` and the
            last axis of the ``u`` placeholder. Defaults to 0.
        dim_k: Stored as ``self.dim_k``; not otherwise used in this constructor.
        **kwargs: Optional overrides. Each key is popped from ``kwargs`` as it
            is read:

            * ``mu``, ``Sigma``, ``y_0``, ``A``, ``B``, ``Q``, ``C``, ``R``:
              initial values for the variables of the same name. Defaults are
              zeros (``mu``, ``y_0``) or ``np.eye``-based matrices (the rest).
              NumPy values, lists and numbers are cast to float32. Other
              objects (e.g. a TF tensor) are handed to ``tf.get_variable``
              unchanged.
            * ``y``, ``u``, ``mask``: tensors to use instead of creating new
              float32 placeholders.
            * ``alpha``, ``state``: stored unchanged (default ``None``).

            Any other keys are silently ignored.

    ``mu``, ``Sigma``, ``Q`` and ``R`` are created with ``trainable=False``.

    NOTE(review): variable names are fixed strings ('mu', 'A', ...), so two
    instances in the same variable scope presumably need ``reuse`` or distinct
    scopes -- confirm against the callers.
    """
    self.dim_z = dim_z
    self.dim_y = dim_y
    self.dim_u = dim_u
    self.dim_k = dim_k

    # Initializer for identity(-like) matrices: eye_init((n, m)) == np.eye(n, m).
    self.eye_init = lambda shape, dtype=np.float32: np.eye(*shape, dtype=dtype)

    def _as_float32(value):
        # Keep every initial value float32 so the variables agree with the
        # float32 placeholders and identity constant created below.
        if hasattr(value, 'astype'):  # NumPy arrays/scalars: same cast as before
            return value.astype(np.float32)
        if isinstance(value, (list, tuple, int, float)):
            return np.asarray(value, dtype=np.float32)
        # Anything else (e.g. a TF tensor or initializer) is left for
        # tf.get_variable to interpret.
        return value

    def _variable(name, default, trainable=True):
        # Pop the caller's override (or use the default), cast it and create
        # the variable. One code path for all eight parameters means none of
        # them can skip the float32 cast.
        init = _as_float32(kwargs.pop(name, default))
        return tf.get_variable(name, initializer=init, trainable=trainable)

    # Pop all variables (creation order matches the original listing).
    self.mu = _variable('mu', np.zeros((dim_z,), dtype=np.float32), trainable=False)  # state
    self.Sigma = _variable('Sigma', self.eye_init((dim_z, dim_z)), trainable=False)  # uncertainty covariance
    self.y_0 = _variable('y_0', np.zeros((dim_y,)))  # initial output
    self.A = _variable('A', self.eye_init((dim_z, dim_z)))  # (dim_z, dim_z); presumably the state transition matrix
    self.B = _variable('B', self.eye_init((dim_z, dim_u)))  # control transition matrix
    self.Q = _variable('Q', self.eye_init((dim_z, dim_z)), trainable=False)  # process uncertainty
    self.C = _variable('C', self.eye_init((dim_y, dim_z)))  # Measurement function
    self.R = _variable('R', self.eye_init((dim_y, dim_y)), trainable=False)  # measurement-side uncertainty (dim_y, dim_y)

    self._alpha_sq = tf.constant(1., dtype=tf.float32)  # fading memory control
    self.M = 0  # process-measurement cross correlation

    # identity matrix
    self._I = tf.constant(self.eye_init((dim_z, dim_z)), name='I')

    # Get variables that are possibly defined with tensors.
    # Placeholder leading axes are left as None; presumably (batch, time) -- TODO confirm.
    self.y = kwargs.pop('y', None)
    if self.y is None:
        self.y = tf.placeholder(tf.float32, shape=(None, None, dim_y), name='y')

    self.u = kwargs.pop('u', None)
    if self.u is None:
        self.u = tf.placeholder(tf.float32, shape=(None, None, dim_u), name='u')

    self.mask = kwargs.pop('mask', None)
    if self.mask is None:
        self.mask = tf.placeholder(tf.float32, shape=(None, None), name='mask')

    self.alpha = kwargs.pop('alpha', None)
    self.state = kwargs.pop('state', None)
    self.log_likelihood = None
# (web-page residue: "评论列表" = "Comment list")
# (web-page residue: "文章目录" = "Table of contents")