def generate_speed_graph(data, filename="als_speed.png",
                         keys=('gpu', 'cg2', 'cg3', 'cholesky'),
                         labels=None, colours=None):
    """Plot seconds-per-iteration against the number of factors and save it.

    One line is drawn per name in ``keys``. Each line is labelled with text
    placed just to the right of its last data point rather than in a legend.

    Args:
        data: Mapping containing a ``'factors'`` sequence (x axis) and, for
            every name in ``keys``, a sequence of timings aligned with it.
        filename: Path the rendered image is written to.
        keys: Names of the series in ``data`` to plot, in drawing order.
        labels: Optional per-key label overrides. Keys not present here fall
            back to the module-level ``LABELS`` mapping, then to the key name.
        colours: Optional per-key colour overrides. Keys not present here fall
            back to the module-level ``COLOURS`` mapping; ``None`` lets
            matplotlib choose a colour.

    Raises:
        KeyError: if ``data`` lacks ``'factors'`` or any name in ``keys``.
    """
    labels = labels or {}
    colours = colours or {}
    seaborn.set()
    fig, ax = plt.subplots()
    try:
        factors = data['factors']
        for key in keys:
            timings = data[key]
            ax.plot(factors, timings,
                    color=colours.get(key, COLOURS.get(key)),
                    marker='o', markersize=6)
            # Consult LABELS lazily so an explicit override works even for
            # keys that LABELS does not know about.
            label = labels[key] if key in labels else LABELS.get(key, key)
            # len() (not truthiness) so numpy arrays are accepted too; an
            # empty series has no last point to anchor the label on.
            if len(factors) and len(timings):
                ax.text(factors[-1] + 5, timings[-1], label, fontsize=10)
        ax.set_ylabel("Seconds per Iteration")
        ax.set_xlabel("Factors")
        fig.savefig(filename, bbox_inches='tight', dpi=300)
    finally:
        # Release the figure so repeated calls do not accumulate open
        # figures in pyplot's global state.
        plt.close(fig)
评论列表
文章目录