def trainable_variables_on_device(self, rel_device_num, abs_device_num,
                                  writable):
    """Collect the trainable variables for the given device.

    Args:
      rel_device_num: local worker device index; used to select the entry of
        `self.variable_mgr.staging_vars_on_devices` to read from.
      abs_device_num: global graph device index. Accepted but not used.
      writable: whether the returned variables are writable or read-only.

    Returns:
      If `writable` is truthy, the result of `tf.trainable_variables()`.
      Otherwise, a list holding one item per trainable variable: the second
      element of the pair stored under that variable's name (with everything
      from the first ':' onward removed) in
      `self.variable_mgr.staging_vars_on_devices[rel_device_num]`.
    """
    del abs_device_num  # Part of the signature only; never read.
    trainable = tf.trainable_variables()
    if not writable:
        # Variable names look like "<name>:<suffix>"; the staging table is
        # keyed by the part before the colon.
        base_names = (variable.name.split(':')[0] for variable in trainable)
        # Lazy lookup keeps the per-variable evaluation order of a plain loop.
        staged_entries = (
            self.variable_mgr.staging_vars_on_devices[rel_device_num][base_name]
            for base_name in base_names)
        # Each entry is unpacked as a pair; only its second element is kept.
        return [read_op for _, read_op in staged_entries]
    return trainable
评论列表
文章目录