def prettyPlot(samps, dat, hid):
    """Show a log-scaled heatmap of ``dat`` with each cell's value printed on it.

    Args:
        samps: Sequence of sample sizes used as tick labels. The x axis
            shows them in order; the y axis shows them reversed.
        dat: 2-D array indexed as ``dat[row, col]`` and providing ``.min()``
            and ``.max()``, with at least ``len(samps)`` rows and columns.
            The colour scale is logarithmic, so values are expected to be
            strictly positive.
        hid: Unused; kept so existing callers keep working.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    labels = list(samps)  # accept any sequence, not only lists
    count = len(labels)
    sz = 18

    _, ax = plt.subplots()
    # Global tick-size defaults, kept from the original implementation.
    plt.rc('xtick', labelsize=sz)
    plt.rc('ytick', labelsize=sz)

    # Draw the array that was passed in, so colours and printed values
    # always come from the same data.
    ax.imshow(
        dat,
        cmap='viridis',
        norm=colors.LogNorm(vmin=dat.min(), vmax=dat.max()),
    )

    # One tick per cell centre; pinning the positions keeps each label
    # attached to its cell without a dummy leading entry.
    positions = list(range(count))
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, fontsize=sz)
    ax.set_yticks(positions)
    ax.set_yticklabels(labels[::-1], fontsize=sz)

    ax.set_xlabel('Number of Experts', fontsize=sz+2)
    ax.set_ylabel('Minibatch Size', fontsize=sz+2)
    ax.set_title('MOE Cell Speedup Factor', fontsize=sz+4)

    # Show cell values. imshow draws dat[row, col] at (x=col, y=row), so the
    # text must be placed with the column as x and the row as y.
    for row in range(count):
        for col in range(count):
            # Labels are the first four characters of the value's string form.
            ax.text(col, row, str(dat[row, col])[:4],
                    ha='center', va='center', fontsize=sz, color='white')

    plt.show()
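# Page navigation residue from the source web page (not code):
# "评论列表" (comment list)
# "文章目录" (article table of contents)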