def test_grad_argmin(self):
    """Check how gradients interact with ``argmin``.

    Runs ``utt.verify_grad`` on ``argmin`` applied to the array returned by
    ``rand(2, 3)``: along ``axis=-1``, along ``axis=[0]``, along
    ``axis=[1]``, and on the flattened input.  Then checks that calling
    ``grad`` directly on an ``argmin`` output raises ``TypeError``.

    Raises:
        Exception: if ``grad(cost, n)`` completes without raising
            ``TypeError``.
    """
    data = rand(2, 3)
    n = as_tensor_variable(data)
    n.name = 'n'

    # test grad of argmin
    utt.verify_grad(lambda v: argmin(v, axis=-1), [data])
    utt.verify_grad(lambda v: argmin(v, axis=[0]), [data])
    utt.verify_grad(lambda v: argmin(v, axis=[1]), [data])
    utt.verify_grad(lambda v: argmin(v.flatten()), [data])

    # Build the cost outside the ``try`` so that only the ``grad`` call
    # itself can satisfy the expected TypeError; a TypeError raised while
    # constructing the graph must fail the test instead of passing it.
    cost = argmin(n, axis=-1)
    cost.name = None
    try:
        grad(cost, n)
    except TypeError:
        # Expected outcome.
        # NOTE(review): presumably grad rejects this cost because of its
        # dtype/shape -- confirm against grad's documentation.
        pass
    else:
        raise Exception('Expected an error')
# Comment list
# Article table of contents