def test_identity(tensor_shape):
    """Exercise the identity initializer for the given ``tensor_shape``.

    For shapes of rank <= 2 the initializer is passed to ``_runner`` together
    with ``target_mean=1 / tensor_shape[0]`` and ``target_max=1.``. For shapes
    of rank greater than 2 the same call is expected to raise.

    Args:
        tensor_shape: shape tuple under test. Presumably supplied by a
            ``pytest.mark.parametrize`` decorator outside this view
            (TODO confirm).
    """
    # An n x n identity matrix holds n ones among n * n entries, so its mean is
    # 1 / n. Derive n from the shape under test instead of a module-level
    # constant so the expectation tracks whichever shape is parametrized.
    target_mean = 1. / tensor_shape[0]

    if len(tensor_shape) > 2:
        # Keep only the initializer call inside the broad ``Exception`` guard,
        # so a mistake while building the arguments cannot make the test pass
        # by accident.
        # NOTE(review): consider narrowing to the concrete exception type the
        # initializer raises for unsupported shapes.
        with pytest.raises(Exception):
            _runner(initializations.identity, tensor_shape,
                    target_mean=target_mean, target_max=1.)
    else:
        _runner(initializations.identity, tensor_shape,
                target_mean=target_mean, target_max=1.)
# Stray web-page labels captured with this snippet ("Comment list", "Table of contents"); not code.