def sample(self, dist_info):
    """Draw one category per row and return it as a one-hot tensor.

    Args:
        dist_info: dict whose "prob" entry holds category probabilities.
            ``tf.multinomial`` requires 2-D input, so this is presumably
            (batch, num_categories) with num_categories == self.dim --
            TODO confirm against callers.

    Returns:
        float32 tensor with one one-hot row (depth ``self.dim``) per
        sampled id.
    """
    prob = dist_info["prob"]
    # tf.multinomial takes (unnormalized) log-probabilities. TINY is defined
    # elsewhere -- presumably a small positive constant that keeps log()
    # finite for zero-probability categories; verify.
    # num_samples=1 gives shape (batch, 1); [:, 0] drops the sample axis.
    ids = tf.multinomial(tf.log(prob + TINY), num_samples=1)[:, 0]
    # Equivalent to gathering rows of a float32 identity matrix, but avoids
    # embedding a dim x dim constant into the graph on every call.
    return tf.one_hot(ids, depth=self.dim, dtype=tf.float32)
评论列表
文章目录