def sentence_length_vs_ponder_time(config,processed_data):
plotting_data = []
for data in processed_data:
steps = len(data["act_probs"])
hyp_length = sum([1 for x in data["hypothesis"] if x!="PAD"])
prem_length = sum([1 for x in data["premise"] if x!="PAD"])
avg_length = (hyp_length + prem_length)/2
correct = data["correct"]
type_class = data["class"]
plotting_data.append([steps, avg_length, correct, type_class])
plotting_data = pd.DataFrame(np.vstack(plotting_data), columns=["steps", "avg_length", "correct", "class"])
seaborn.violinplot(x="steps", y="avg_length",
hue="correct",split=True,
data=plotting_data, inner="quartile", scale="count")
plt.show()
# for class_type in [0.0,1.0,2.0]:
# fig = plt.figure()
# x_vals = [x[0] for x in plotting_data if (x[3]==class_type and x[2]==1.0)]
# y_vals = [x[1] for x in plotting_data if (x[3]==class_type and x[2]==1.0)]
# print("Class: ",class_type, "No. Correct: ", len(x_vals))
# plt.scatter(x_vals, y_vals,color="g")
#
# x_vals = [x[0] for x in plotting_data if (x[3]==class_type and x[2]==0.0)]
# y_vals = [x[1] for x in plotting_data if (x[3]==class_type and x[2]==0.0)]
# print("Class: ",class_type, "No. Incorrect: ", len(x_vals))
#
# plt.scatter(x_vals, y_vals,color="r")
#
# ax = plt.gca()
# ax.set_xlabel("ACT Steps")
# ax.set_ylabel("avg hyp/premise length")
# ax.set_title("test_title")
# plt.show()
评论列表
文章目录