def sample_sym(self, dist_info):
    """Build a graph op that draws one categorical sample per row, one-hot encoded.

    Args:
        dist_info: dict whose ``"prob"`` entry holds category probabilities.
            ``tf.multinomial`` requires 2-D input, so this is presumably a
            ``(batch, self.dim)`` tensor -- TODO confirm against callers.

    Returns:
        A float32 tensor with one row per input row; each row is a one-hot
        vector of length ``self.dim`` marking the sampled category.
    """
    probs = dist_info["prob"]
    # tf.multinomial takes (unnormalized) log-probabilities; the epsilon keeps
    # zero-probability entries finite (about -18.4) instead of log(0) = -inf.
    logits = tf.log(probs + 1e-8)
    # One draw per row gives shape (batch, 1); drop the trailing axis.
    samples = tf.multinomial(logits, num_samples=1)[:, 0]
    # Same result as looking rows up in np.eye(self.dim), but without baking a
    # dim x dim constant matrix into the graph on every call.
    return tf.one_hot(samples, depth=self.dim, dtype=tf.float32)