core_process_time_evaluator.py source code

python

Project: DeepPoseComparison   Author: ynaka81
```python
import os
import time

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm, trange


# Method of the evaluator class in core_process_time_evaluator.py
# (the surrounding class body is omitted in this snippet).
def plot(self, samples, title):
    """Plot the core process time of each tester (chainer and pytorch) against batch size."""
    batch_sizes = [2**m for m in range(self.max_batch_index + 1)]
    # one result slot per registered tester, so no tester name is hard-coded.
    process_time = {name: {'mean': [], 'std': []} for name in self.process}
    for batch_size in tqdm(batch_sizes, desc='batch size'):
        for name, process in tqdm(self.process.items(), desc='testers'):
            # set batch size.
            process.set_batch_size(batch_size)
            compute_time = []
            # time each run with a monotonic, high-resolution clock.
            for _ in trange(samples, desc='samples'):
                start = time.perf_counter()
                process.run(self.only_inference)
                compute_time.append(time.perf_counter() - start)
            # calculate mean and std.
            process_time[name]['mean'].append(np.mean(compute_time))
            process_time[name]['std'].append(np.std(compute_time))
    # draw on a fresh figure so repeated calls do not overlay earlier plots.
    plt.figure()
    for name, p_t in process_time.items():
        plt.errorbar(batch_sizes, p_t['mean'], yerr=p_t['std'], label=name)
    # plot settings.
    plt.title(title)
    plt.legend(loc='lower right')
    plt.xlabel('batch size')
    plt.ylabel('core process time [sec]')
    # show the plot in debug mode, otherwise save it to the output directory.
    if self.debug:
        plt.show()
    else:
        filename = '_'.join(title.split(' ')) + '.png'
        plt.savefig(os.path.join(self.output, filename))
    # release the figure's memory once it has been shown or saved.
    plt.close()
```
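The measurement loop in `plot()` can be exercised without chainer or pytorch installed. The sketch below is a stdlib-only approximation under stated assumptions: `SleepProcess` and `measure_core_time` are hypothetical names, not part of DeepPoseComparison; the tester interface (`set_batch_size(batch_size)` and `run(only_inference)`) is inferred only from how `plot()` calls it; and `statistics.pstdev` stands in for `np.std`, since both compute the population standard deviation (NumPy's default `ddof=0`).

```python
import statistics
import time


class SleepProcess:
    """Hypothetical stand-in for a chainer/pytorch tester (not from DeepPoseComparison).

    It exposes the two methods plot() calls, set_batch_size() and run(),
    and simulates work by sleeping in proportion to the batch size.
    """

    def __init__(self, seconds_per_sample):
        self.seconds_per_sample = seconds_per_sample
        self.batch_size = 1

    def set_batch_size(self, batch_size):
        self.batch_size = batch_size

    def run(self, only_inference):
        # Assumption: a training step costs about twice an inference-only step.
        factor = 1.0 if only_inference else 2.0
        time.sleep(self.seconds_per_sample * self.batch_size * factor)


def measure_core_time(processes, max_batch_index, samples, only_inference=True):
    """Time `samples` runs per tester at batch sizes 1, 2, 4, ..., 2**max_batch_index.

    Returns (batch_sizes, stats), where stats[name]['mean'] and stats[name]['std']
    hold one value per batch size, in the same order as batch_sizes.
    """
    batch_sizes = [2**m for m in range(max_batch_index + 1)]
    stats = {name: {'mean': [], 'std': []} for name in processes}
    for batch_size in batch_sizes:
        for name, process in processes.items():
            process.set_batch_size(batch_size)
            times = []
            for _ in range(samples):
                start = time.perf_counter()
                process.run(only_inference)
                times.append(time.perf_counter() - start)
            stats[name]['mean'].append(statistics.mean(times))
            # population standard deviation, matching np.std's default ddof=0
            stats[name]['std'].append(statistics.pstdev(times))
    return batch_sizes, stats


if __name__ == '__main__':
    testers = {'fast': SleepProcess(0.0005), 'slow': SleepProcess(0.001)}
    sizes, stats = measure_core_time(testers, max_batch_index=3, samples=3)
    for name, s in stats.items():
        for b, m, sd in zip(sizes, s['mean'], s['std']):
            print(f'{name:>5} batch={b:<2} mean={m * 1e3:.2f} ms std={sd * 1e3:.2f} ms')
```

Wall-clock timing like this only reflects real compute cost if `run()` returns after the work has finished. GPU frameworks launch kernels asynchronously, so a tester running on CUDA would need to synchronize (for PyTorch, `torch.cuda.synchronize()`) before returning; whether the original testers do so is not shown in this file.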