test_execution.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def one_hot_comparison(hot_axes, axes, C):
    """
    TODO.

    Arguments:
      hot_axes: TODO
      axes: TODO
    """
    u = rng.random_integers(0, C.length - 1, axes, dtype=np.int8)
    u_p = ng.placeholder(axes, dtype=u.dtype)
    v = np.zeros(hot_axes.lengths, dtype=np.float32)
    udxiter = np.nditer(u, flags=['multi_index'])
    for uiter in udxiter:
        vindex = [int(uiter)]
        vindex.extend(udxiter.multi_index)
        v[tuple(vindex)] = 1

    with executor(ng.one_hot(u_p, axis=C), u_p) as ex:
        v_t = ex(u)
        ng.testing.assert_allclose(v_t, v)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号