def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
    """Create a saver swapping moving averages and variables.

    You should use this saver during training. It will save the moving
    averages of the trained parameters under the original parameter names.
    For evaluations or inference you should use a regular saver and it will
    automatically use the moving averages for the trained variable.

    You must call this function after all variables have been created and
    after you have called Optimizer.minimize().

    Args:
        var_list: List of variables to save, as per `Saver()`.
            If set to None, will save all the variables that have been
            created before this call (`tf.global_variables()`). A dict is
            used as-is as the checkpoint-name -> variable mapping; any other
            value is converted with `saver.BaseSaverBuilder.OpListToDict`.
        name: The name of the saver.
        **kwargs: Keyword arguments of `Saver()`.

    Returns:
        A `tf.train.Saver` object.

    Raises:
        RuntimeError: If apply_gradients or minimize has not been called
            before.
    """
    # A None map means apply_gradients/minimize has not run yet, so there
    # are no moving averages to swap in.
    if self._variable_map is None:
        raise RuntimeError('Must call apply_gradients or minimize before '
                           'creating the swapping_saver')
    if var_list is None:
        var_list = tf.global_variables()
    if not isinstance(var_list, dict):
        var_list = saver.BaseSaverBuilder.OpListToDict(var_list)
    # Now swap variables and moving averages. Each entry whose variable's op
    # name is a key of _variable_map is saved from the mapped object, but
    # under the original checkpoint key; all other entries are kept unchanged.
    # NOTE(review): assumes every value exposes `.op.name` (a single
    # variable) - confirm behavior for partitioned variables.
    swapped_var_list = {}
    for k, v in six.iteritems(var_list):
        v_swap = self._variable_map.get(v.op.name)
        # Check presence explicitly instead of relying on truthiness: calling
        # bool() on a variable object is not a reliable "found" test (for some
        # variable implementations it reads the variable's value, which can
        # raise in graph mode or be False for a zero-valued variable).
        swapped_var_list[k] = v_swap if v_swap is not None else v
    # Build the swapping saver.
    return saver.Saver(swapped_var_list, name=name, **kwargs)
评论列表
文章目录