def assign_sub(self, delta, name=None, read_value=True):
    """Mimic `assign_sub` on the variable by routing the delta through staging.

    The delta is not subtracted directly. It is put into a newly created
    `StagingArea` placed on the device of `self.var_stage_get`. The put op is
    appended to `self.variable_mgr.staging_delta_ops` so it can be pumped
    later. The value taken back out of the staging area is what gets
    subtracted from `self.real_var`.

    Args:
      delta: is pushed into a staging buffer and will be pumped later. The
        StagingArea is declared with the dtype and shape of
        `self.var_stage_get`, so `delta` is expected to match those.
      name: currently ignored; names of ops and the StagingArea are
        computed without using this passed name.
      read_value: mirrors the `read_value` argument of
        `tf.Variable.assign_sub`. When True (the default), the real
        variable's `assign_sub` is called exactly as before. When False,
        `read_value=False` is forwarded to it.

    Returns:
      The result of `self.real_var.assign_sub` applied to the value read back
      from the staging buffer. That op is created outside the
      `colocate_with(None, True)` scope, so the colocation constraint is
      reapplied to it.
    """
    # This parameter is ignored: the StagingArea only supports setting
    # the shared name, not the names of individual ops it uses.
    del name

    stage_get = self.var_stage_get
    # colocate_with(None, True) clears the colocation constraints; the ops
    # below are then placed by the explicit device scope instead.
    with ops.colocate_with(None, True), tf.device(stage_get.device):
        delta_staging_area = data_flow_ops.StagingArea(
            [stage_get.dtype], shapes=[stage_get.shape])
        delta_put_op = delta_staging_area.put([delta])
        # Record the put op so the staging buffer can be pumped later.
        self.variable_mgr.staging_delta_ops.append(delta_put_op)
        delta_get_op = delta_staging_area.get()[0]

    # Return the actual updates. The colocation constraint will be reapplied
    # because this op is built outside the colocate_with scope above.
    if read_value:
        return self.real_var.assign_sub(delta_get_op)
    # Pass read_value only when it differs from the default. This keeps the
    # default path identical to the original call, in case the underlying
    # variable's assign_sub does not accept the argument (NOTE(review):
    # depends on the TF version / variable type in use).
    return self.real_var.assign_sub(delta_get_op, read_value=False)
评论列表
文章目录