def __init__(self, energy_fn, prior, std=1.0,
             inter_op_parallelism_threads=1, intra_op_parallelism_threads=1):
    """Build a random-walk Metropolis-Hastings sampling graph and a session.

    Each transition adds Gaussian noise (mean 0.0, standard deviation
    ``std``) to the current state to form a proposal. It then calls
    ``metropolis_hastings_accept`` with the energies of the current and
    proposed states, and keeps the proposal where accepted, otherwise the
    current state. ``tf.scan`` repeats the transition ``self.steps`` times,
    starting from ``energy_fn.z``.

    Args:
        energy_fn: Callable mapping a state tensor to an energy tensor. It
            must also expose the starting-state tensor as ``energy_fn.z``.
        prior: Stored as ``self.prior``; not used by this constructor.
        std: Standard deviation of the Gaussian random-walk proposal.
        inter_op_parallelism_threads: Forwarded to ``tf.ConfigProto``.
        intra_op_parallelism_threads: Forwarded to ``tf.ConfigProto``.

    Attributes set:
        steps: Scalar int32 placeholder giving the number of transitions.
        z_: Output of ``tf.scan``, i.e. the state after every transition,
            stacked along a new leading axis.
        sess: ``tf.Session`` configured with the given thread counts.
    """
    self.energy_fn = energy_fn
    self.prior = prior
    self.z = self.energy_fn.z

    def _mh_step(current, _unused_elem):
        # One MH transition. The scanned element is ignored; it only exists
        # so that the length of the scanned sequence sets the step count.
        # The noise shape is read from the initial state self.z, which is
        # the scan initializer, so it matches the shape of ``current``.
        noise = tf.random_normal(tf.shape(self.z), 0.0, std)
        proposal = current + noise
        prev_energy = energy_fn(current)
        next_energy = energy_fn(proposal)
        # NOTE(review): assumes ``accept`` is a boolean tensor usable as a
        # tf.where condition against ``current``; confirm at its definition.
        accept = metropolis_hastings_accept(
            energy_prev=prev_energy,
            energy_next=next_energy,
        )
        return tf.where(accept, proposal, current)

    self.steps = tf.placeholder(tf.int32, [])
    # Dummy sequence whose only role is to make tf.scan run `steps` times.
    step_ticks = tf.zeros([self.steps])
    self.z_ = tf.scan(_mh_step, step_ticks, initializer=self.z, back_prop=False)

    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=inter_op_parallelism_threads,
        intra_op_parallelism_threads=intra_op_parallelism_threads,
    )
    self.sess = tf.Session(config=session_config)
    self.sess.run(tf.global_variables_initializer())
# Comment list
# Table of contents