def sample_sym(self, dist_info):
    """Symbolically draw one category per row and return it one-hot encoded.

    Args:
        dist_info: Mapping whose ``"prob"`` entry holds the category
            probabilities (presumably a ``(batch, dim)`` tensor -- confirm
            against the caller).

    Returns:
        The result of looking up the drawn category indices in a float32
        ``self.dim`` x ``self.dim`` identity matrix, i.e. one one-hot row
        per drawn index.
    """
    # The small epsilon keeps the log finite when a probability is exactly 0.
    log_probs = tf.log(dist_info["prob"] + 1e-8)

    # A single draw per row; indexing [:, 0] drops the num_samples axis.
    draws = tf.multinomial(log_probs, num_samples=1)
    category_ids = draws[:, 0]

    # Rows of the identity matrix are the one-hot vectors for each index.
    one_hot_rows = np.eye(self.dim, dtype=np.float32)
    return tf.nn.embedding_lookup(one_hot_rows, category_ids)
文章目录