def variables_to_restore(self, moving_avg_variables=None):
    """Build a map from restore names to the variables they load into.

    Every variable treated as having a moving average is keyed by
    ``self.average_name(v)``. Every other global variable is keyed by its
    own op name. The exception is when that op name is already used as an
    average name; in that case the earlier entry wins.

    NOTE(review): this map is presumably meant to be handed to a saver so
    that averaged values are restored into the original variables. Confirm
    against callers.

    Args:
      moving_avg_variables: Optional iterable of variables to treat as
        having a moving average. If ``None``, the concatenation of
        ``tf.trainable_variables()`` and ``tf.moving_average_variables()``
        is used.

    Returns:
      A dict mapping name strings to variable objects.
    """
    if moving_avg_variables is None:
        # Trainable variables plus any explicitly registered in the
        # moving-average collection. Build a new list with ``+`` instead of
        # extending one in place with ``+=``.
        moving_avg_variables = (
            tf.trainable_variables() + tf.moving_average_variables())
    # The two sources can overlap; a set removes duplicates.
    moving_avg_variables = set(moving_avg_variables)

    # Variables with a moving average are keyed by their average name.
    name_map = {self.average_name(v): v for v in moving_avg_variables}

    # `tf.all_variables` is deprecated in favor of `tf.global_variables`;
    # fall back to the old name on TensorFlow releases that lack the new one.
    global_variables = getattr(tf, "global_variables", None) or tf.all_variables
    # Also restore variables without a moving average, under their own op
    # name. Skip names already claimed by an average name (presumably the
    # averaged/shadow variables themselves) so those entries are not
    # overwritten.
    for v in set(global_variables()) - moving_avg_variables:
        if v.op.name not in name_map:
            name_map[v.op.name] = v
    return name_map
#===============================================================
# Comment list
# Table of contents