def test_callback_with_ifelse(self):
    """Check the callback count recorded for an ``ifelse`` graph.

    Builds ``ifelse(a, 2 * b, 2 * c)`` over three scalar variables,
    compiles it with ``optimizer=None`` and a ``VM_Linker`` that is given
    ``self.callback``, calls the compiled function once with ``(1, 2, 3)``
    and asserts that ``self.n_callbacks['IfElse']`` equals 2.
    """
    # The string 'abc' yields one scalar variable per character.
    cond, on_true, on_false = tensor.scalars('abc')
    branch = ifelse(cond, 2 * on_true, 2 * on_false)

    # optimizer=None: compile the graph as written.
    # NOTE(review): presumably self.callback tallies per-op invocations
    # into self.n_callbacks -- confirm against the fixture's definition.
    callback_linker = vm.VM_Linker(callback=self.callback)
    callback_mode = Mode(optimizer=None, linker=callback_linker)

    compiled = function([cond, on_true, on_false], branch,
                        mode=callback_mode)
    compiled(1, 2, 3)

    assert self.n_callbacks['IfElse'] == 2
评论列表
文章目录