def update(self, x):
    """Fold one batch of samples into the running mean and variance.

    Builds TF1-style graph ops that advance the step counter ``self.t``,
    shift every ``self.mean[k]`` by the chain-averaged, step-weighted
    deviation of ``x[k]``, and then overwrite ``self.var[k]`` with the
    decayed old variance plus a chain-averaged spread term.

    Args:
        x: Sequence of tensors paired element-wise with ``self.mean`` and
            ``self.var``. Per this method's shape notes, each element is
            (chain_dims data_dims) and is averaged over ``self.chain_axes``.

    Returns:
        A list of ``tf.assign`` ops, one per ``self.var`` entry. The new
        variances are computed from the outputs of the step-counter and
        mean-update ops, so running these ops also runs those updates.
    """
    # Step size (1 - decay) / (1 - decay ** t), with t read after the
    # increment. If self.t starts at 0, the first call has step size 1.
    t_next = tf.assign(self.t, self.t + 1)
    alpha = (1 - self.decay) / (1 - tf.pow(self.decay, t_next))

    # deltas[k] = alpha * (x[k] - mean[k]); shape (chain_dims data_dims).
    deltas = []
    for sample, running_mean in zip(x, self.mean):
        deltas.append(alpha * (sample - running_mean))

    # Move each mean by its delta averaged over the chain axes. The
    # assign_add outputs (shape (1,...,1 data_dims)) hold the means
    # *after* the update and feed the variance term below.
    shifted_means = []
    for running_mean, delta in zip(self.mean, deltas):
        chain_avg = tf.reduce_mean(delta, axis=self.chain_axes,
                                   keep_dims=True)
        shifted_means.append(running_mean.assign_add(chain_avg))

    # New variance: (1 - alpha) * var + chain-mean of
    # delta * (x - updated_mean); shape (1,...,1 data_dims).
    targets = []
    for running_var, delta, sample, new_mean in zip(
            self.var, deltas, x, shifted_means):
        decayed = (1 - alpha) * running_var
        spread = tf.reduce_mean(delta * (sample - new_mean),
                                axis=self.chain_axes, keep_dims=True)
        targets.append(decayed + spread)

    return [tf.assign(running_var, target)
            for running_var, target in zip(self.var, targets)]
# (Web-page residue captured with this snippet, not code:
#  "Comment list" / "Table of contents")