benchmark_cnn.py source code


Project: benchmarks  Author: tensorflow
def run(self):
    """Run the benchmark task assigned to this process.

    Returns:
      Dictionary of statistics for training or eval.
    Raises:
      ValueError: unrecognized job name.
    """
    if self.params.job_name == 'ps':
      log_fn('Running parameter server %s' % self.task_index)
      self.cluster_manager.join_server()
      return {}

    # For distributed_all_reduce with multiple workers, drive
    # from a separate controller process.
    if self.params.variable_update == 'distributed_all_reduce':
      if self.params.job_name == 'worker':
        log_fn('Starting worker %s' % self.task_index)
        self.cluster_manager.join_server()
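        # Return an empty dict (not None) so every non-raising path matches
        # the docstring's "Dictionary of statistics" contract.
        return {}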
      elif self.params.job_name and self.params.job_name != 'controller':
        raise ValueError('unrecognized job name: %s' % self.params.job_name)

    with tf.Graph().as_default():
      if self.params.eval:
        return self._eval_cnn()
      else:
        return self._benchmark_cnn()
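To make the dispatch order in run() easier to follow, here is a minimal, TensorFlow-free sketch that replays the same branching. Everything in it is hypothetical: `FakeClusterManager`, `MiniBenchmark`, and `make` are illustrative stand-ins, not part of the benchmarks project. The fake `join_server()` only counts calls; in a real cluster the corresponding call is expected to block while the server handles requests. Logging and the `tf.Graph()` context are omitted. Unlike the bare `return` in the excerpt, the worker path here returns `{}` so that every non-raising path returns a dict.

```python
from types import SimpleNamespace


class FakeClusterManager:
    """Hypothetical stand-in: records join_server() calls instead of blocking."""

    def __init__(self):
        self.joined = 0

    def join_server(self):
        self.joined += 1


class MiniBenchmark:
    """Hypothetical copy of the run() dispatch logic, with no TensorFlow."""

    def __init__(self, params, task_index=0):
        self.params = params
        self.task_index = task_index
        self.cluster_manager = FakeClusterManager()

    def _eval_cnn(self):
        return {'mode': 'eval'}  # placeholder for real eval statistics

    def _benchmark_cnn(self):
        return {'mode': 'train'}  # placeholder for real training statistics

    def run(self):
        # Parameter servers only serve variables; they never build a model.
        if self.params.job_name == 'ps':
            self.cluster_manager.join_server()
            return {}
        # With distributed_all_reduce, workers just host a server and a
        # separate controller process drives the computation.
        if self.params.variable_update == 'distributed_all_reduce':
            if self.params.job_name == 'worker':
                self.cluster_manager.join_server()
                return {}
            elif self.params.job_name and self.params.job_name != 'controller':
                raise ValueError('unrecognized job name: %s' % self.params.job_name)
        # The real method wraps this in `with tf.Graph().as_default():`.
        return self._eval_cnn() if self.params.eval else self._benchmark_cnn()


def make(job_name, variable_update='parameter_server', evaluate=False):
    """Build a MiniBenchmark from a hypothetical params namespace."""
    params = SimpleNamespace(job_name=job_name,
                             variable_update=variable_update,
                             eval=evaluate)
    return MiniBenchmark(params)
```

Note that an unrecognized job name is rejected only in `distributed_all_reduce` mode. With other `variable_update` settings, any non-`ps` job (or an empty job name for a local run) falls through to building the graph.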