def _adapt_mass(self, t, num_chain_dims):
    """Return the list of mass tensors to use at step ``t``.

    A fresh ``ExponentialWeightedMovingVariance`` is built from
    ``self.mass_decay``, ``self.data_shapes`` and ``num_chain_dims``.

    If ``self.adapt_mass`` is true, its precision is updated from
    ``self.q``. Otherwise its current precision is read without an update.

    While ``t`` is below ``self.mass_collect_iters``, an all-ones mass of
    each shape in ``self.data_shapes`` is returned instead. After that, the
    moving-variance precision is returned.

    Args:
        t: Step counter; cast to int32 before comparing against
            ``self.mass_collect_iters``.
        num_chain_dims: Forwarded unchanged to
            ``ExponentialWeightedMovingVariance``.

    Returns:
        A list of tensors, one per entry of the selected ``tf.cond`` branch.
    """
    # NOTE(review): the estimator is re-created on every call; whether that
    # allocates new state each time depends on its implementation -- confirm.
    ewmv = ExponentialWeightedMovingVariance(
        self.mass_decay, self.data_shapes, num_chain_dims)
    # NOTE(review): TF1 ``tf.cond`` expects a tensor predicate, so
    # ``self.adapt_mass`` is presumably a boolean tensor -- confirm.
    new_mass = tf.cond(self.adapt_mass,
                       lambda: ewmv.get_updated_precision(self.q),
                       lambda: ewmv.precision())
    # Non-strict tf.cond hands back a bare tensor when a branch yields a
    # single-element list; normalise to a list for the code below.
    if not isinstance(new_mass, list):
        new_mass = [new_mass]
    # NOTE(review): it is unverified here that each precision tensor has the
    # matching shape in ``self.data_shapes``; check the estimator's output.
    # Force the precision computation (and any update it performs) to run
    # before the mass selection below.
    with tf.control_dependencies(new_mass):
        current_mass = tf.cond(
            tf.less(tf.to_int32(t), self.mass_collect_iters),
            # Identity mass during the collection phase.
            # Both cond branches must share dtypes, so take each one from
            # the corresponding adapted mass instead of defaulting to float32.
            lambda: [tf.ones(shape, dtype=mass.dtype)
                     for shape, mass in zip(self.data_shapes, new_mass)],
            lambda: new_mass)
    if not isinstance(current_mass, list):
        current_mass = [current_mass]
    return current_mass
评论列表
文章目录