def multinomial_2d(x, seed=None):
    """Draw one categorical sample per row of a 2D tensor of logits.

    Args:
        x: Tensor with shape (batch_size, classes) holding unnormalized
            log-probabilities (logits). Each row is handed to the sampler
            as-is, so pass logits, not probabilities.
        seed: Optional Python integer used as the op-level random seed.
            Defaults to None (unseeded), matching the previous behavior.

    Returns:
        Tensor with shape (batch_size,). Each entry is a class index in
        [0, classes) drawn from the corresponding row. The dtype is int64,
        the sampler's default.
    """
    # `tf.multinomial` is deprecated and absent from the TF2 top-level
    # namespace. Prefer `tf.random.categorical` (the same op under its
    # current name) and fall back for older TF 1.x releases.
    random_ns = getattr(tf, "random", None)
    sample = getattr(random_ns, "categorical", None) or tf.multinomial
    samples = sample(x, 1, seed=seed)  # shape (batch_size, 1)
    # Drop the num_samples axis. Unlike reshaping via tf.shape(x)[0],
    # squeeze preserves a statically known batch dimension.
    return tf.squeeze(samples, axis=1)
评论列表
文章目录