def gumbel_softmax_sample(logits, temperature):
    """Draw a sample from the Gumbel-Softmax distribution.

    Noise obtained from ``sample_gumbel(tf.shape(logits))`` is added to
    ``logits``; the sum is divided by ``temperature`` and passed through
    ``tf.nn.softmax``.

    Args:
        logits: Tensor of scores to perturb. Presumably unnormalized
            log-probabilities with the categories on the axis that
            ``tf.nn.softmax`` normalizes by default -- confirm against
            callers.
        temperature: Divisor applied to the perturbed logits before the
            softmax. Smaller values sharpen the softmax output; larger
            values flatten it.

    Returns:
        The tensor produced by ``tf.nn.softmax`` on the scaled, perturbed
        logits.
    """
    noisy_logits = tf.add(logits, sample_gumbel(tf.shape(logits)))
    # Use the tensor division operator rather than ``tf.div``: ``tf.div`` is
    # deprecated and missing from the TensorFlow 2.x top-level namespace.
    # ``noisy_logits`` is always a tensor here, and for the floating-point
    # values softmax requires the two forms give the same result.
    return tf.nn.softmax(noisy_logits / temperature)