def plot_dist(train_y, dev_y, test_y, save_path='checkpoints/label_dist.pdf'):
    """Plot normalized label histograms for the train/validation/test splits.

    Draws one histogram per split in a 3x1 grid of subplots (x-axis limited
    to [0, 500]), saves the figure as a PDF, then shows it.

    Args:
        train_y: Labels of the training split (any 1-D array-like accepted by
            ``seaborn.distplot``).
        dev_y: Labels of the validation split.
        test_y: Labels of the test split.
        save_path: Where to write the PDF. Defaults to the previously
            hard-coded ``checkpoints/label_dist.pdf``; its parent directory
            is created if it does not already exist.

    Side effects:
        Globally enables LaTeX text rendering and the Times-Roman font in
        matplotlib's rcParams and sets the seaborn style to "white"; these
        settings persist after the call. Rendering with ``usetex=True``
        requires a working LaTeX installation.
    """
    import os

    import matplotlib.pyplot as plt
    import seaborn as sns

    plt.rc('text', usetex=True)
    plt.rc('font', family='Times-Roman')
    sns.set_style(style='white')

    # (labels, legend text, bar colour) for each row of the figure, top to bottom.
    splits = [
        (train_y, 'Training', 'blue'),
        (dev_y, 'Validation', 'green'),
        (test_y, 'Test', 'red'),
    ]

    fig = plt.figure(figsize=(8, 12))
    for row, (values, label, colour) in enumerate(splits, start=1):
        ax = fig.add_subplot(len(splits), 1, row)
        # NOTE: distplot is deprecated since seaborn 0.11 (histplot with
        # stat="density" is the replacement); kept so the helper still works
        # on older seaborn versions.
        # norm_hist=True makes bar heights a density rather than a raw count,
        # so the "Frequency" axis is a normalized scale.
        sns.distplot(values, kde=False, label=label, hist=True,
                     norm_hist=True, color=colour, ax=ax)
        ax.set_xlabel("Answer")
        ax.set_ylabel("Frequency")
        ax.set_xlim([0, 500])
        ax.legend(loc='best')

    # savefig raises if the target directory is missing, so create it first.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(save_path, format='pdf', dpi=300)
    plt.show()
评论列表
文章目录