def fn_device_dependency(name, device=""):
    """Add control deps for name and device.

    Chains successive uses of the same ``name``/``device`` pair: whatever
    was stored under the key ``name + "_" + device`` in
    ``fn_device_dependency_dict()`` is applied via
    ``tf.control_dependencies`` around the yield point, and whatever the
    caller appends to the yielded list is stored back under the same key
    afterwards.

    NOTE(review): this function returns a generator; it is presumably
    wrapped with ``contextlib.contextmanager`` outside this block so it can
    be used in a ``with`` statement -- confirm.
    NOTE(review): assumes ``fn_device_dependency_dict()`` returns a mapping
    that yields a usable value for keys not stored yet (e.g. a
    ``defaultdict(list)``) -- confirm.

    Args:
        name: String identifying the dependency chain.
        device: Optional device string. It is always part of the lookup
            key; when non-empty, the yield point is additionally wrapped
            in ``tf.device(device)``.

    Returns:
        A generator that yields a single empty list. The caller must append
        its outputs to that list before resuming the generator: either the
        individual outputs, or exactly one list/tuple of outputs.

    Raises:
        AssertionError: If the yielded list is still empty when the
            generator resumes, or if its first element is a list/tuple and
            it is not the only element.
    """
    key = name + "_" + device
    outs = []

    def record_outputs():
        # Store what the caller produced so the next use of `key` depends on it.
        assert outs, "no outputs were appended to the yielded list"
        deps = outs
        if isinstance(outs[0], (list, tuple)):
            # A single list/tuple of outputs was appended: use its members.
            assert len(outs) == 1, "a list/tuple output must be the only entry"
            deps = outs[0]
        fn_device_dependency_dict()[key] = deps

    def body():
        with tf.control_dependencies(fn_device_dependency_dict()[key]):
            yield outs
            record_outputs()

    def body_on_device():
        # The device scope has to be entered inside the generator. Calling a
        # generator function only creates the generator object, so wrapping
        # that call in `with tf.device(...)` would exit the scope before any
        # of the generator body (and the caller's ops) runs.
        with tf.device(device):
            with tf.control_dependencies(fn_device_dependency_dict()[key]):
                yield outs
                record_outputs()

    if device:
        return body_on_device()
    return body()
# [web page residue] Comment list
# [web page residue] Article table of contents