def plot(self, samples, title):
    """Measure and plot the core process time of each registered tester.

    For every batch size ``2**m`` with ``0 <= m <= self.max_batch_index``,
    each process in ``self.process`` has its batch size set and is then run
    ``samples`` times. The mean and standard deviation of those timings are
    drawn as one error-bar curve per process name.

    Each value in ``self.process`` must provide ``set_batch_size(batch_size)``
    and ``run(only_inference)``; ``run`` receives ``self.only_inference``.

    If ``self.debug`` is true the figure is shown interactively; otherwise it
    is saved into the ``self.output`` directory and then closed.

    Args:
        samples: Number of timed runs per (batch size, process) pair.
        title: Plot title. Also used to build the output filename: spaces
            are replaced by underscores and ``.png`` is appended.
    """
    batch_sizes = [2**m for m in range(self.max_batch_index + 1)]
    # One series per registered process (instead of a hard-coded
    # chainer/pytorch pair) so the keys always match ``self.process``.
    process_time = {name: {'mean': [], 'std': []} for name in self.process}
    for batch_size in tqdm(batch_sizes, desc='batch size'):
        for name, process in tqdm(self.process.items(), desc='testers'):
            process.set_batch_size(batch_size)
            compute_time = []
            for _ in trange(samples, desc='samples'):
                # perf_counter is monotonic and high-resolution, unlike
                # time.time(), which makes it the right clock for timing.
                start = time.perf_counter()
                process.run(self.only_inference)
                compute_time.append(time.perf_counter() - start)
            process_time[name]['mean'].append(np.mean(compute_time))
            process_time[name]['std'].append(np.std(compute_time))
    # Draw on a fresh figure so repeated calls do not overlay earlier curves.
    fig = plt.figure()
    for name, p_t in process_time.items():
        plt.errorbar(batch_sizes, p_t['mean'], yerr=p_t['std'], label=name)
    plt.title(title)
    plt.legend(loc='lower right')
    plt.xlabel('batch size')
    plt.ylabel('core process time [sec]')
    if self.debug:
        plt.show()
    else:
        filename = '_'.join(title.split(' ')) + '.png'
        plt.savefig(os.path.join(self.output, filename))
        # Release the figure; otherwise every call leaves one open.
        plt.close(fig)
core_process_time_evaluator.py source file
python
Views 18
Favorites 0
Likes 0
Comments 0
Comment list
Table of contents