def variable_device(device, name):
    """Fix the variable device to colocate its ops.

    Resolves ``device`` into the device to use for a variable called
    ``name`` in the current variable scope.

    Args:
      device: A device specification (e.g. a string such as ``'/cpu:0'``),
        ``None``, or a callable device function. A callable is invoked with
        a ``NodeDef`` for an op of type ``'Variable'`` named after the
        variable, and its return value is used as the device.
      name: The variable's name, relative to the current variable scope.

    Returns:
      The resolved device. ``''`` is returned when ``device`` is ``None`` or
      when the device function returns ``None``.
    """
    if callable(device):
        scope_name = tf.get_variable_scope().name
        # In the root variable scope the scope name is empty; joining
        # unconditionally would produce '/<name>', which is not a valid op
        # name and would not match the variable's actual name.
        var_name = scope_name + '/' + name if scope_name else name
        var_def = tf.NodeDef(name=var_name, op='Variable')
        device = device(var_def)
    # Normalize "no device" to the empty string. This covers both a None
    # argument and a device function that returned None.
    if device is None:
        device = ''
    return device
评论列表
文章目录