def _initialize_metrics(self):
    """Build the per-data-scope metric tables for the model.

    For each of ``Data.TRAIN``, ``Data.VALIDATE`` and ``Data.TEST`` (in that
    order) this populates four dicts keyed by the data scope:

    * ``self.metrics`` -- the mapping returned by ``self.collect_metrics``.
    * ``self.metric_values`` -- metric name -> that entry's ``'scalar'``.
    * ``self.update_metrics`` -- list of every entry's ``'update_op'``.
    * ``self.reset_metrics`` -- a ``tf.variables_initializer`` over the
      ``LOCAL_VARIABLES`` collected under ``stats_utils.metric_scope``.
    """
    self.metrics, self.metric_values = {}, {}
    self.update_metrics, self.reset_metrics = {}, {}
    for scope_key in (Data.TRAIN, Data.VALIDATE, Data.TEST):
        scope_metrics = self.collect_metrics(scope_key)
        self.metrics[scope_key] = scope_metrics
        self.metric_values[scope_key] = {
            metric_name: entry['scalar']
            for metric_name, entry in iteritems(scope_metrics)}
        self.update_metrics[scope_key] = [
            entry['update_op'] for entry in itervalues(scope_metrics)]
        # NOTE(review): presumably these local variables are the metric
        # accumulators, so running the initializer resets them -- confirm
        # against stats_utils.metric_scope / collect_metrics.
        metric_scope = stats_utils.metric_scope(scope_key, graph=self.graph)
        with metric_scope as name_scope:
            metric_variables = list(
                tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, name_scope))
        self.reset_metrics[scope_key] = tf.variables_initializer(
            metric_variables)
评论列表
文章目录