learning_distributed.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def _setup_model_loss(self, inputs, labels, validation_inputs, validation_labels, is_chief, task_id, num_workers, is_training, scope, initial_lr=0.1, reuse=None, global_step=None, num_replicas_to_aggregate=-1):
        """Build training/validation losses and a synchronous-replica optimizer.

        Args:
            inputs: training input batch, forwarded to ``self._tower_loss``.
            labels: training labels, forwarded to ``self._tower_loss``.
            validation_inputs: validation input batch.
            validation_labels: validation labels; also passed as the first
                argument of every metric function.
            is_chief: if True, an exponential moving average (decay 0.9) of the
                individual and total training losses is applied, and the
                returned ``total_loss`` depends on that update.
            task_id: worker index; not used by this method.
            num_workers: total number of worker replicas; passed to the
                optimizer as ``total_num_replicas``.
            is_training: training-mode flag forwarded to the training tower.
            scope: scope forwarded to ``self._tower_loss``.
            initial_lr: base learning rate given to
                ``self.lr_policy.batch_update``.
            reuse: reuse flag for the training tower (the validation tower is
                always built with ``True``).
            global_step: passed as the second (``num_updates``) argument of the
                variable moving averager.
            num_replicas_to_aggregate: number of replica gradients to aggregate
                per update. A non-positive value (the default ``-1``) means
                "aggregate from all workers", i.e. ``num_workers``.

        Returns:
            ``(total_loss, opt, validation_metric)``: ``opt`` is a
            ``tf.train.SyncReplicasOptimizer``. ``validation_metric`` holds one
            score per entry of ``self.validation_metrics_def``, in order,
            followed by the validation loss.
        """
        self.learning_rate = tf.placeholder(
            tf.float32, shape=[], name="learning_rate_placeholder")

        losses, total_loss = self._tower_loss(
            scope, self.model, inputs, labels, is_training, reuse, is_classification=True)
        # Validation tower: built with is_training=False and reuse=True.
        val_total_loss = self._tower_loss(
            scope, self.model, validation_inputs, validation_labels, False, True, is_classification=True)

        # One score per configured metric, followed by the validation loss.
        # NOTE(review): self.validation_predictions is presumably set by the
        # validation _tower_loss call above — confirm.
        validation_metric = []
        if self.validation_metrics_def:
            # Build the class-index tensor once and share it across metrics.
            predicted_classes = tf.argmax(self.validation_predictions, 1)
            validation_metric = [
                metric_function(validation_labels, predicted_classes)
                for _, metric_function in self.validation_metrics_def
            ]
        validation_metric.append(val_total_loss)

        if is_chief:
            loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
            loss_averages_op = loss_averages.apply(losses + [total_loss])

            # Evaluating total_loss now also runs the loss-average update.
            with tf.control_dependencies([loss_averages_op]):
                total_loss = tf.identity(total_loss)

        exp_moving_averager = tf.train.ExponentialMovingAverage(
            self.cnf.get('mv_decay', 0.9), global_step)

        variables_to_average = (
            tf.trainable_variables() + tf.moving_average_variables())

        # -1 (the default) or any non-positive count means "wait for every
        # worker". Passing it straight through would give
        # SyncReplicasOptimizer an invalid aggregation count.
        if num_replicas_to_aggregate <= 0:
            num_replicas_to_aggregate = num_workers

        # NOTE(review): the self.learning_rate placeholder created above is not
        # passed to the optimizer; the rate comes from a single
        # lr_policy.batch_update(initial_lr, 0) call here. Values fed to the
        # placeholder presumably do not reach this optimizer — confirm intent.
        learning_rate = self.lr_policy.batch_update(initial_lr, 0)
        opt = self._optimizer(learning_rate, optname=self.cnf.get(
            'optname', 'momentum'), **self.cnf.get('opt_kwargs', {'decay': 0.9}))
        # Create synchronous replica optimizer.
        opt = tf.train.SyncReplicasOptimizer(
            opt,
            replicas_to_aggregate=num_replicas_to_aggregate,
            total_num_replicas=num_workers,
            variable_averages=exp_moving_averager,
            variables_to_average=variables_to_average)
        return total_loss, opt, validation_metric
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号