def get_value_updater(self, data, new_mean, gamma_weighted, gamma_sum):
    """Build an op that refits the low-rank covariance parameters.

    Computes the ``gamma_weighted``-weighted scatter matrix of ``data``
    around ``new_mean``, optionally applies the prior adjustment, and keeps
    the top ``self.rank`` singular values/vectors. The spectral mass not
    captured by those components is divided by ``self.tf_rest`` to form a
    scalar baseline.

    Args:
        data: Sample matrix; the original outer-product code required it to
            be rank 2, presumably (N, D) -- TODO confirm.
        new_mean: Updated mean, broadcast against every row of ``data``
            (presumably shape (D,)).
        gamma_weighted: Per-sample weights, assumed rank 1 of length N
            (presumably responsibilities already normalised by
            ``gamma_sum``; verify against caller).
        gamma_sum: Only forwarded to ``self.get_prior_adjustment`` when
            ``self.has_prior`` is set.

    Returns:
        A ``tf.group`` op that assigns ``self.tf_baseline``,
        ``self.tf_eigvals`` and ``self.tf_eigvecs``.
    """
    tf_new_differences = tf.subtract(data, tf.expand_dims(new_mean, 0))
    # Weighted scatter sum_n g_n * d_n d_n^T computed as a single
    # (D, N) x (N, D) matmul. Mathematically identical to summing
    # per-sample outer products, but never materialises an (N, D, D)
    # intermediate, so memory stays O(N*D + D^2) instead of O(N*D^2).
    tf_weighted_differences = tf_new_differences * tf.expand_dims(gamma_weighted, 1)
    tf_new_covariance = tf.matmul(tf_weighted_differences, tf_new_differences, transpose_a=True)
    # The matmul form is symmetric only up to floating-point rounding;
    # symmetrise explicitly before the SVD (cheap: D x D).
    tf_new_covariance = 0.5 * (tf_new_covariance + tf.transpose(tf_new_covariance))
    if self.has_prior:
        tf_new_covariance = self.get_prior_adjustment(tf_new_covariance, gamma_sum)
    # tf.svd returns (s, u, v) with singular values sorted in descending
    # order, so the leading slices are the dominant components. For a
    # symmetric PSD matrix these singular values equal the eigenvalues
    # (assumes get_prior_adjustment preserves PSD-ness -- TODO confirm).
    tf_s, tf_u, _ = tf.svd(tf_new_covariance)
    tf_required_eigvals = tf_s[:self.rank]
    tf_required_eigvecs = tf_u[:, :self.rank]
    # Spectral mass not explained by the kept components, divided by
    # self.tf_rest (presumably D - rank, i.e. the mean discarded
    # eigenvalue -- TODO confirm).
    tf_new_baseline = (tf.trace(tf_new_covariance) - tf.reduce_sum(tf_required_eigvals)) / self.tf_rest
    # Kept eigenvalues are stored relative to the baseline.
    tf_new_eigvals = tf_required_eigvals - tf_new_baseline
    # Columns of u are the eigenvectors; store them as rows instead.
    tf_new_eigvecs = tf.transpose(tf_required_eigvecs)
    return tf.group(
        self.tf_baseline.assign(tf_new_baseline),
        self.tf_eigvals.assign(tf_new_eigvals),
        self.tf_eigvecs.assign(tf_new_eigvecs)
    )
评论列表
文章目录