def _static_leading_dim(s_data):
    """Return the statically known size of axis 0 of ``s_data``, or None.

    Returns None when ``s_data`` exposes no ``get_shape`` method, when its
    static rank is unknown, or when its axis-0 dimension is unknown.

    Raises:
        ValueError: if ``s_data`` is statically known to be a scalar, since a
            scalar has no axis 0 to gather from.
    """
    get_shape = getattr(s_data, "get_shape", None)
    if get_shape is None:
        return None
    shape = get_shape()
    # ``as_list()`` raises on an unknown rank, so test the rank first; that
    # way the caller can report the helpful "please supply 'total_size'" error.
    if shape.ndims is None:
        return None
    if shape.ndims == 0:
        raise ValueError("s_data must have rank >= 1, got a scalar")
    return shape.as_list()[0]


def combinations(s_data, subset_size, total_size=None, name=None):
    """Gather every ``subset_size``-element combination of ``s_data`` rows.

    All index combinations of ``range(total_size)`` are enumerated in Python
    with ``itertools.combinations`` (lexicographic order). They are stored in
    a constant tensor and used to index ``s_data`` along axis 0 with
    ``tf.gather``.

    Args:
        s_data: Tensor whose axis 0 enumerates the items to combine.
        subset_size: Number of items per combination; a positive ``int`` no
            larger than ``total_size``.
        total_size: Number of leading items along axis 0 to draw from. When
            ``None``, it is read from the static shape of ``s_data``.
        name: Name of the constant op holding the index combinations;
            defaults to ``'combs'``.

    Returns:
        ``tf.gather(s_data, combs)``, where ``combs`` has shape
        ``[C(total_size, subset_size), subset_size]``. The result therefore
        has shape ``[C(total_size, subset_size), subset_size]`` followed by
        ``s_data``'s remaining dimensions.

    Raises:
        TypeError: if ``subset_size``, or an explicitly given ``total_size``,
            is not an ``int``.
        ValueError: if ``subset_size`` is not in ``[1, total_size]``; if
            ``total_size`` is omitted and the size of axis 0 is not statically
            known; if ``total_size`` exceeds the statically known size of
            axis 0; or if ``s_data`` is a scalar.
    """
    # Explicit raises instead of ``assert``: assertions vanish under
    # ``python -O``, and invalid sizes would then build a wrongly shaped
    # index constant instead of failing.
    if not isinstance(subset_size, int):
        raise TypeError(
            "subset_size must be an int, got {}".format(
                type(subset_size).__name__))
    if subset_size <= 0:
        raise ValueError(
            "subset_size must be positive, got {}".format(subset_size))

    static_size = _static_leading_dim(s_data)
    if total_size is None:
        total_size = static_size
        if total_size is None:
            raise ValueError(
                "tensor size on axis 0 is unknown,"
                " please supply 'total_size'")
    else:
        if not isinstance(total_size, int):
            raise TypeError(
                "total_size must be an int, got {}".format(
                    type(total_size).__name__))
        # tf.gather reports out-of-range indices on CPU but writes zeros on
        # GPU. Reject an oversized total_size up front whenever the real size
        # is statically known.
        if static_size is not None and total_size > static_size:
            raise ValueError(
                "total_size ({}) exceeds the size of axis 0 of s_data "
                "({})".format(total_size, static_size))
    if subset_size > total_size:
        raise ValueError(
            "subset_size ({}) must not exceed total_size ({})".format(
                subset_size, total_size))

    # All C(total_size, subset_size) index tuples are materialised in Python
    # memory before becoming a constant, so this grows combinatorially.
    c_combs = tf.constant(
        list(itertools.combinations(range(total_size), subset_size)),
        dtype=hparams.INTX,
        name=('combs' if name is None else name))
    return tf.gather(s_data, c_combs)
评论列表
文章目录