def __init__(self, dim):
    """Build a compiled sampler for a categorical distribution over ``dim`` classes.

    Args:
        dim: Number of categories; it fixes the second axis of the
            ``(None, dim)`` weights placeholder.

    Side effects:
        Sets ``self._dim`` to ``dim`` and ``self._f_sample`` to the compiled
        sampling function. ``self._f_sample`` is built from
        ``tensor_utils.compile_function``. It presumably returns a callable
        that takes an array of shape ``(batch, dim)`` and evaluates the
        sampling op; confirm against ``tensor_utils``.
    """
    self._dim = dim
    # One row of category weights per batch element. The name "weights"
    # suggests non-negative (probability-like) values, not logits.
    # NOTE(review): confirm callers pass probabilities/weights here.
    weights_var = tf.placeholder(
        dtype=tf.float32,
        shape=(None, dim),
        name="weights",
    )
    # tf.multinomial expects *logits* (unnormalized log-probabilities).
    # Feeding raw weights would sample from softmax(weights) rather than in
    # proportion to the weights, so convert to logits first.
    # softmax(log(w)) == w / sum(w), so rows need not be normalized.
    # The small epsilon keeps log() finite for zero-weight categories.
    logits = tf.log(weights_var + 1e-8)
    self._f_sample = tensor_utils.compile_function(
        inputs=[weights_var],
        # Draw one sample per row, then drop the trailing num_samples axis
        # so the output has one category index per batch element.
        outputs=tf.multinomial(logits, num_samples=1)[:, 0],
    )
评论列表
文章目录