def plot_accuracies(self, accuracies, scales=(), mode='train', fig=0):
    """Plot one accuracy curve per entry of ``accuracies`` and save a PNG.

    Each entry of ``accuracies`` is drawn against its own index range
    (``0 .. len(entry) - 1``) on figure ``fig``, which is cleared first.
    The image is written to ``<self.path>/accuracies_<mode>.png``.

    Args:
        accuracies: Iterable of sequences; each sequence becomes one curve.
        scales: Sized collection whose ``str()`` values are used as the
            legend labels, matched positionally with the curves.
        mode: Used in the output file name. ``'train'`` labels the x axis
            ``'iterations'``; any other value labels it
            ``'iterations x 1000'``.
        fig: Figure identifier passed to ``plt.figure``.
    """
    plt.figure(fig)
    plt.clf()

    # Materialise once so the curves can be counted and then iterated.
    accuracies = list(accuracies)

    # The palette must cover every curve. Sizing it by ``len(scales)`` alone
    # raised IndexError whenever fewer scales than curves were supplied
    # (including the default empty ``scales``). Using the larger of the two
    # counts keeps the colours unchanged when ``scales`` already covers all
    # curves.
    n_colors = max(len(scales), len(accuracies))
    colors = cm.rainbow(np.linspace(0, 1, n_colors))

    names = [str(sc) for sc in scales]
    handles = []
    for i, acc in enumerate(accuracies):
        # ``plt.plot`` returns a list of lines; a single series yields one.
        line, = plt.plot(range(len(acc)), acc, color=colors[i])
        handles.append(line)

    plt.ylabel('accuracy')
    plt.legend(handles, names, loc=2, prop={'size': 6})
    if mode == 'train':
        plt.xlabel('iterations')
    else:
        plt.xlabel('iterations x 1000')

    path = os.path.join(self.path, 'accuracies_{}.png'.format(mode))
    plt.savefig(path)
评论列表
文章目录