def test_callback(self):
    """Check that the VM linker callback count tracks the graph size.

    Compiles ``(x + y) + z`` with ``optimizer=None`` and a ``VM_Linker``
    wired to ``self.callback``. After each call, it asserts that the
    total recorded in ``self.n_callbacks`` equals the number of nodes in
    the compiled graph's toposort multiplied by the number of calls so far.

    NOTE(review): assumes ``self.callback`` increments entries of
    ``self.n_callbacks`` and that the dict starts empty (e.g. via setUp).
    Both are defined outside this view, so confirm.
    """
    x, y, z = tensor.scalars('abc')
    total = (x + y) + z
    # optimizer=None, presumably so the graph is executed as written
    # and the node count is predictable.
    mode = Mode(optimizer=None,
                linker=vm.VM_Linker(callback=self.callback))
    fn = function([x, y, z], total, mode=mode)

    # Counts accumulate across calls, so the expected total grows
    # linearly with the number of calls made so far.
    for n_calls in (1, 2):
        fn(1, 2, 3)
        assert (sum(self.n_callbacks.values()) ==
                len(fn.maker.fgraph.toposort()) * n_calls)
评论列表
文章目录