def test_grad(self):
    """Verify the gradient of ``Cast`` over format/dtype combinations.

    Iterates over every entry of ``sparse.sparse_formats``, every input
    dtype in ``sparse.float_dtypes`` and every output dtype in
    ``tensor.float_dtypes`` (skipping ``'float16'`` output). For each
    combination it obtains data of shape (4, 7) from
    ``sparse_random_inputs`` and hands ``Cast(o_dtype)`` together with
    that data to ``verify_grad_sparse``.
    """
    for sparse_format in sparse.sparse_formats:
        for i_dtype in sparse.float_dtypes:
            for o_dtype in tensor.float_dtypes:
                if o_dtype == 'float16':
                    # Don't test float16 output.
                    continue

                # Data is regenerated for every (format, input dtype,
                # output dtype) triple; the first returned value is
                # not needed here.
                _, data = sparse_random_inputs(
                    sparse_format,
                    shape=(4, 7),
                    out_dtype=i_dtype)

                # An explicit eps is passed only for float32 output.
                # NOTE(review): presumably None lets verify_grad_sparse
                # fall back to its own default -- confirm.
                eps = 1e-2 if o_dtype == 'float32' else None

                verify_grad_sparse(Cast(o_dtype), data, eps=eps)
# (Web-page residue, not part of the test: "Comment list" / "Table of contents")