def test_grad_argmax(self):
    """Check gradients involving ``argmax``.

    First, run ``utt.verify_grad`` on ``argmax`` with several axis
    specifications: a negative int, single-element lists, and a
    flattened input with no axis. Then check that taking ``grad`` of
    ``argmax(n, axis=-1)`` directly with respect to its input raises
    ``TypeError``.
    """
    # NOTE(review): presumably a random 2x3 float array -- confirm `rand`.
    data = rand(2, 3)
    n = as_tensor_variable(data)

    # Test grad of argmax for each supported way of specifying the axis.
    utt.verify_grad(lambda v: argmax(v, axis=-1), [data])
    utt.verify_grad(lambda v: argmax(v, axis=[0]), [data])
    utt.verify_grad(lambda v: argmax(v, axis=[1]), [data])
    utt.verify_grad(lambda v: argmax(v.flatten()), [data])

    # Calling grad on the raw argmax output must fail with TypeError.
    # NOTE(review): likely because the cost is non-scalar and/or
    # integer-valued -- confirm against `grad`'s implementation.
    # Only the call under test sits in the `try`; a missing error is
    # reported from `else` as a test failure (AssertionError) rather
    # than a generic Exception raised from inside the guarded region.
    try:
        grad(argmax(n, axis=-1), n)
    except TypeError:
        pass
    else:
        raise AssertionError(
            "Expected grad(argmax(n, axis=-1), n) to raise TypeError")
评论列表
文章目录