def save_plot(niters, loss, args):
    """Save a loss-vs-iteration line plot as a JPEG in the working directory.

    The figure title and file name encode the run configuration:
    ``Train-<train_file>_hs-<hs>_lr-<lr>_bs-<batch_size>.jpg``.
    If matplotlib cannot be imported, a message is printed and nothing is
    saved (best-effort: training code should not fail because plotting is
    unavailable).

    Note: this switches the process-wide matplotlib backend to 'Agg'.

    Args:
        niters: x-axis values (iteration numbers), passed to ``Axes.plot``.
        loss: y-axis values (loss at each entry of ``niters``).
        args: object exposing ``train_file``, ``hs``, ``lr`` and
            ``batch_size`` attributes (presumably an argparse.Namespace).
    """
    print('Saving training loss-iteration figure...')
    # Keep the try narrow: only a missing matplotlib should be reported as
    # "not installed"; errors from plotting/saving must surface normally.
    try:
        import matplotlib
        # Non-interactive backend: render straight to file, no display needed.
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
    except ImportError:
        print('matplotlib not installed and no figure is saved.')
        return

    name = 'Train-{}_hs-{}_lr-{}_bs-{}'.format(args.train_file, args.hs,
                                               args.lr, args.batch_size)
    filename = name + '.jpg'
    # Use a dedicated figure instead of pyplot's implicit "current" one so
    # repeated calls don't overlay curves, and always close it so figures
    # don't accumulate in pyplot's registry.
    fig, ax = plt.subplots()
    try:
        ax.set_title(name)
        ax.plot(niters, loss)
        ax.set_xlabel('iteration')
        ax.set_ylabel('loss')
        fig.savefig(filename)
    finally:
        plt.close(fig)
    print('{} saved!'.format(filename))
评论列表
文章目录