def __init__(self, log_prior, log_joint, prior_sampler,
             hmc, observed, latent, n_chains=25, n_temperatures=1000,
             verbose=False):
    """Build the tempered log-density and the ops that sample from it.

    The tempered objective is
    ``log_prior * (1 - temperature) + log_joint * temperature``, where
    ``temperature`` is a scalar float32 placeholder fed at run time.

    Args:
        log_prior: Callable taking one dict argument; its result is
            weighted by ``1 - temperature``.
        log_joint: Callable taking one dict argument; its result is
            weighted by ``temperature``.
        prior_sampler: Mapping whose values are assigned to the latent
            variables by ``self.init_latent``.
        hmc: Object exposing ``sample(log_fn, observed, latent)`` that
            returns a ``(sample_op, info)`` pair.
        observed: Dict passed to ``hmc.sample`` and merged with
            ``latent`` to evaluate the tempered objective.
        latent: Dict whose values are targets of ``tf.assign``
            (presumably ``tf.Variable`` objects -- confirm with callers).
        n_chains: Stored on the instance; not otherwise used here.
        n_temperatures: Stored on the instance; not otherwise used here.
        verbose: Stored on the instance; not otherwise used here.
    """
    # Shape of latent: [chain_axis, num_data, data dims]
    self.n_chains = n_chains
    self.n_temperatures = n_temperatures
    self.verbose = verbose

    # NOTE(review): the nesting below assumes every graph op is created
    # inside the "AIS" name scope -- confirm against the upstream file.
    with tf.name_scope("AIS"):
        self.temperature = tf.placeholder(tf.float32, shape=[],
                                          name="temperature")

        # Construct the tempered objective. The parameter keeps the name
        # ``observed`` (shadowing the outer argument) so that keyword
        # calls to ``self.log_fn`` continue to work.
        def log_fn(observed):
            return (log_prior(observed) * (1 - self.temperature) +
                    log_joint(observed) * self.temperature)

        self.log_fn = log_fn
        self.log_fn_val = log_fn(merge_dicts(observed, latent))
        self.sample_op, self.hmc_info = hmc.sample(
            log_fn, observed, latent)

        if set(latent) == set(prior_sampler):
            # Same keys on both sides: pair by key so the result does not
            # depend on each dict's iteration order. Iterating ``latent``
            # keeps the assign ops in the same creation order as before.
            self.init_latent = [tf.assign(z, prior_sampler[name])
                                for name, z in latent.items()]
        else:
            # Keys differ: keep the original positional pairing so
            # existing callers behave exactly as before.
            self.init_latent = [tf.assign(z, z_s)
                                for z, z_s in zip(latent.values(),
                                                  prior_sampler.values())]
# [Web-page residue, not part of the program: "Comment list" / "Article table of contents"]