def plot(dims, sequence, factorization, filename='speed.png'):
    """Plot the fitting-speed improvement of two models against embedding size.

    The y-values are the reciprocals of ``sequence`` and ``factorization``.
    They are drawn against ``dims`` on a log-scaled x-axis, and the figure is
    saved to ``filename``.

    Parameters
    ----------
    dims : sequence of numbers
        Embedding-layer sizes, used as the x-values.
    sequence : sequence of numbers
        Per-size values for the sequence model. Presumably fitting time
        relative to a baseline, where 1.0 means no change; TODO confirm
        against the caller. Plain lists and arrays are both accepted.
    factorization : sequence of numbers
        Per-size values for the factorization model, in the same units as
        ``sequence``.
    filename : str or path-like, optional
        Destination of the saved figure. Defaults to ``'speed.png'``, the
        previously hard-coded path.
    """
    # Select the non-interactive Agg backend so no display is required.
    import matplotlib
    matplotlib.use('Agg')  # NOQA
    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns

    # Must run before the axes are created for the style to apply to them.
    sns.set_style("darkgrid")

    # Convert to float arrays first. ``1.0 / [...]`` raises TypeError for
    # plain Python lists.
    sequence_speedup = 1.0 / np.asarray(sequence, dtype=float)
    factorization_speedup = 1.0 / np.asarray(factorization, dtype=float)

    # Draw on a dedicated figure rather than pyplot's implicit "current"
    # figure, so content from earlier plots cannot leak into this one.
    fig, ax = plt.subplots()
    try:
        ax.set_ylabel("Speed improvement")
        ax.set_xlabel("Size of embedding layers")
        ax.set_title("Fitting speed (1.0 = no change)")
        ax.set_xscale('log')
        ax.plot(dims, sequence_speedup, label='Sequence model')
        ax.plot(dims, factorization_speedup, label='Factorization model')
        ax.legend(loc='lower right')
        fig.savefig(filename)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
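# Stray web-page navigation text, not part of the program:
# "评论列表" (comment list) / "文章目录" (table of contents).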