ops.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:DaNet-Tensorflow 作者: khaotik 项目源码 文件源码
def combinations(s_data, subset_size, total_size=None, name=None):
    assert isinstance(subset_size, int)
    assert subset_size > 0
    if total_size is None:
        total_size = s_data.get_shape().as_list()[0]

    if total_size is None:
        raise ValueError(
            "tensor size on axis 0 is unknown,"
            " please supply 'total_size'")
    else:
        assert isinstance(total_size, int)
        assert subset_size <= total_size

    c_combs = tf.constant(
        list(itertools.combinations(range(total_size), subset_size)),
        dtype=hparams.INTX,
        name=('combs' if name is None else name))

    return tf.gather(s_data, c_combs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号