def one_hot_comparison(hot_axes, axes, C):
    """
    Check the result of ``ng.one_hot`` against a reference built with NumPy.

    Random integer indices are drawn, a reference array is filled with a 1
    at every ``(index value, *position of that index)`` location, and the
    executed ``ng.one_hot`` result is compared to it with
    ``ng.testing.assert_allclose``.

    Arguments:
        hot_axes: Object whose ``lengths`` attribute is used as the shape of
            the reference array. Presumably this is ``C`` followed by
            ``axes``, since the hot index is placed first -- TODO confirm
            against callers.
        axes: Axes handed to the random integer generator and to the
            placeholder that feeds the indices into the graph.
        C: Axis passed as ``axis=`` to ``ng.one_hot``; ``C.length - 1`` is
            the upper bound given to the random integer generator.
    """
    indices = rng.random_integers(0, C.length - 1, axes, dtype=np.int8)
    indices_ph = ng.placeholder(axes, dtype=indices.dtype)

    # Reference one-hot array: zeros everywhere, a single 1 per index entry.
    expected = np.zeros(hot_axes.lengths, dtype=np.float32)
    cursor = np.nditer(indices, flags=['multi_index'])
    for value in cursor:
        # The drawn value selects the position along the leading dimension;
        # the remaining coordinates are where that value sits in `indices`.
        hot_position = (int(value),) + tuple(cursor.multi_index)
        expected[hot_position] = 1

    one_hot_op = ng.one_hot(indices_ph, axis=C)
    with executor(one_hot_op, indices_ph) as ex:
        actual = ex(indices)
        ng.testing.assert_allclose(actual, expected)
评论列表
文章目录