def local_assert_no_cpu_op(node):
    """React to a CPU op sandwiched between GPU transfers.

    ``node`` is reported when every one of its inputs is produced by a
    ``HostFromGpu`` op and at least one of its outputs is consumed by a
    ``GpuFromHost`` op.  What "reported" means is selected by
    ``config.assert_no_cpu_op``:

    * ``"warn"``  -- log a warning through ``_logger``;
    * ``"raise"`` -- raise ``AssertionError``;
    * ``"pdb"``   -- drop into the debugger via ``pdb.set_trace()``;
    * any other value -- do nothing.

    Parameters
    ----------
    node
        Graph node exposing ``inputs`` (variables with an ``owner``
        attribute) and ``outputs`` (variables with a ``clients`` sequence
        whose entries carry the consuming node as their first element).

    Returns
    -------
    None
        Always; the function only has side effects.

    Raises
    ------
    AssertionError
        If the node is reported and ``config.assert_no_cpu_op == "raise"``.
    """
    # Every input must come straight from a GPU -> host transfer.  A node
    # without inputs passes this test vacuously (``all`` of nothing is True).
    fed_from_gpu = all(
        var.owner and isinstance(var.owner.op, HostFromGpu)
        for var in node.inputs
    )
    if not fed_from_gpu:
        return None

    # At least one output must be shipped back to the GPU.  Client entries
    # whose first element has no ``op`` attribute are skipped instead of
    # raising AttributeError.
    # NOTE(review): presumably this covers the string marker used for graph
    # outputs in the function graph -- confirm against the graph container.
    feeds_gpu = any(
        hasattr(client[0], "op") and isinstance(client[0].op, GpuFromHost)
        for var in node.outputs
        for client in var.clients
    )
    if not feeds_gpu:
        return None

    action = config.assert_no_cpu_op
    if action == "warn":
        # Lazy %-args: the message is only rendered if the record is emitted.
        _logger.warning(
            "CPU Op %s is detected in the computation graph", node)
    elif action == "raise":
        raise AssertionError("The Op %s is on CPU." % (node,))
    elif action == "pdb":
        pdb.set_trace()
    return None
# Register the local_assert_no_cpu_op:
# (Web-page navigation residue, not part of the module: "Comment list" / "Article table of contents".)