def plot_debug_grad(debug, tag, fold_exp, trg):
    """Plot recorded gradient curves, one subplot per entry, and save a PNG.

    The selected gradients are turned into a 2-D array with ``np.array``.
    Columns are consumed in pairs: for pair ``j`` the even column
    ``2*j`` is plotted in the left subplot titled ``"w<j>"`` and the odd
    column ``2*j + 1`` in the right subplot titled ``"b<j>"``. A trailing
    unpaired column (odd column count) is ignored. The figure is written to
    ``<fold_exp>/grad_<trg>.png``.

    Args:
        debug: dict holding gradients under ``"grad_sup"`` and/or
            ``"grad_hint"``. Each value must be a sequence of equal-length
            rows (presumably one row per recorded step -- confirm with the
            caller).
        tag: text appended to the figure title.
        fold_exp: output directory; it is not created here.
        trg: which gradients to plot, ``"sup"`` or ``"hint"``.

    Raises:
        ValueError: if ``trg`` is not ``"sup"``/``"hint"``, or the selected
            gradients contain no complete (w, b) column pair.
        KeyError: if ``debug`` lacks the key selected by ``trg``.
    """
    import os

    keys = {"sup": "grad_sup", "hint": "grad_hint"}
    if trg not in keys:
        raise ValueError("trg must be 'sup' or 'hint', got %r" % (trg,))
    grads = debug[keys[trg]]

    # Size the grid from the gradients actually being plotted (validated
    # before any figure is touched, so bad input has no side effects).
    nbr_rows = len(grads[0]) // 2 if len(grads) > 0 else 0
    if nbr_rows == 0:
        raise ValueError(
            "no (w, b) gradient pairs found in debug[%r]" % (keys[trg],))

    plt.close("all")
    # squeeze=False keeps axs 2-D even for a single row, so the
    # axs[j, col] indexing below always works.
    f, axs = plt.subplots(nbr_rows, 2, sharex=True, sharey=False,
                          figsize=(15, 12.8), dpi=300, squeeze=False)
    try:
        grad = np.array(grads)
        print("%s %s" % (grad.shape, trg))
        for j in range(nbr_rows):
            for col, prefix in enumerate(("w", "b")):
                ax = axs[j, col]
                ax.plot(grad[:, 2 * j + col], label=trg)
                ax.set_title(prefix + str(j))
                ax.grid(True)
        f.suptitle("Grad sup/hint:" + tag, fontsize=8)
        # Every curve carries the same label (trg), so a single legend on
        # the last subplot is enough.
        axs[-1, 1].legend()
        f.savefig(os.path.join(fold_exp, "grad_" + trg + ".png"),
                  bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(f)
tools.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录