def update_weights(self, train_op):
    """Updates the model weights.

    This function must be called on at least one worker after `minimize`.
    In distributed training this call can be omitted on non-chief workers to
    speed up training.

    The returned operation is built in two stages. First, gated on
    `train_op`, every variable in `self._variables` is assigned from its
    matching `'unshrinked_'`-prefixed slot in `self._slots`. Second, gated on
    those assignments, `gen_sdca_ops._sdca_shrink_l1` is applied to each
    variable inside that variable's device scope.

    Args:
      train_op: The operation returned by the `minimize` call.

    Returns:
      An Operation that updates the model weights.
    """
    weight_groups = ('sparse_features_weights', 'dense_features_weights')

    # Stage 1: copy over unshrinked weights to user provided variables.
    with ops.control_dependencies([train_op]):
        copy_ops = [
            weight.assign(unshrinked)
            for group in weight_groups
            for weight, unshrinked in zip(self._variables[group],
                                          self._slots['unshrinked_' + group])
        ]

    # Stage 2: apply the proximal step once the copies above have run.
    with ops.control_dependencies(copy_ops):
        shrink_ops = []
        for group in weight_groups:
            for weight in self._variables[group]:
                # Build the shrink op (and its inputs) under the variable's
                # own device scope, one variable at a time.
                with ops.device(weight.device):
                    weight_refs = self._convert_n_to_tensor(
                        [weight], as_ref=True)
                    l1_strength = self._symmetric_l1_regularization()
                    l2_strength = self._symmetric_l2_regularization()
                    # pylint: disable=protected-access
                    shrink_op = gen_sdca_ops._sdca_shrink_l1(
                        weight_refs, l1=l1_strength, l2=l2_strength)
                shrink_ops.append(shrink_op)
        return control_flow_ops.group(*shrink_ops)
评论列表
文章目录