def log_combination(n, ks):
    r"""
    Compute the log combination (log multinomial coefficient) function.

    .. math::

        \log \binom{n}{k_1, k_2, \dots} = \log n! - \sum_{i}\log k_i!

    :param n: A N-D `float` Tensor. Can broadcast to match `ks.shape[:-1]`.
    :param ks: A (N + 1)-D `float` Tensor. Each slice `[i, j, ..., k, :]` is
        a vector of `[k_1, k_2, ...]`; the last axis is reduced over.
    :return: A N-D Tensor of type same as `n`.
    """
    # log(x!) == lgamma(x + 1), so the factorials are computed through the
    # log-gamma function; this also accepts non-integer float inputs.
    # Note: n is not checked against sum(ks) -- callers are responsible for
    # passing consistent values.
    # NOTE(review): tf.lgamma is the TF1-era name; under TF2 the equivalent
    # is tf.math.lgamma -- confirm which TensorFlow version this file targets.
    return tf.lgamma(n + 1) - tf.reduce_sum(tf.lgamma(ks + 1), axis=-1)
评论列表
文章目录