def plot_debug_ratio_grad(debug, fold_exp, r="h/s"):
    """Plot, per layer, the element-wise ratio of hint and supervised gradients.

    ``debug["grad_sup"]`` and ``debug["grad_hint"]`` are each stacked into a
    numpy array. Columns are read in (weight, bias) pairs: column ``2*k`` is
    drawn in the left subplot titled ``"w<k>"`` and column ``2*k + 1`` in the
    right subplot titled ``"b<k>"``, one row of subplots per pair. A trailing
    unpaired column is ignored. The figure is saved as
    ``<fold_exp>/ratio_grad_<r with "/" replaced by "-">.png``.

    NOTE(review): this assumes each entry of ``debug["grad_sup"]`` /
    ``debug["grad_hint"]`` is one step's sequence of per-parameter gradient
    values laid out as alternating weight/bias entries -- confirm against the
    code that fills ``debug``.

    Args:
        debug: mapping with keys ``"grad_sup"`` and ``"grad_hint"``.
        fold_exp: directory in which the PNG is written.
        r: ``"h/s"`` plots hint / supervised, ``"s/h"`` plots
            supervised / hint.

    Returns:
        0 if the ratio cannot be computed (mismatched shapes or nothing to
        plot); None once the figure has been saved.

    Raises:
        ValueError: if ``r`` is neither ``"h/s"`` nor ``"s/h"``.
    """
    import os

    # Validate everything before touching matplotlib so that no figure is
    # left open on an early exit.
    if r not in ("h/s", "s/h"):
        raise ValueError("Either h/s or s/h.")

    grads = np.array(debug["grad_sup"])
    gradh = np.array(debug["grad_hint"])
    # Compare full shapes, not just sizes: two arrays with the same number of
    # elements but a different layout would otherwise be sliced column by
    # column and silently broadcast against each other.
    if gradh.shape != grads.shape:
        print("Can't calculate the ratio. It looks like you divided the "
              "hint batch...")
        return 0
    print("%s %s" % (gradh.shape, grads.shape))

    nbr_rows = grads.shape[1] // 2 if grads.ndim >= 2 else 0
    if nbr_rows == 0:
        print("Nothing to plot: no (weight, bias) gradient pairs found.")
        return 0

    num, den = (gradh, grads) if r == "h/s" else (grads, gradh)
    # Zero gradients are expected; the resulting inf/nan values are left in
    # place (matplotlib draws them as gaps) without emitting warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = np.divide(num, den)

    plt.close("all")
    # squeeze=False keeps `axs` 2-D even when there is a single row, so
    # `axs[row, col]` indexing works for one-layer models too.
    f, axs = plt.subplots(nbr_rows, 2, sharex=True, sharey=False,
                          figsize=(15, 12.8), dpi=300, squeeze=False)
    for row in range(nbr_rows):
        for col, prefix in enumerate(("w", "b")):
            ax = axs[row, col]
            ax.plot(ratio[:, 2 * row + col], label=r)
            ax.set_title(prefix + str(row))
            ax.grid(True)
            ax.legend()
    f.suptitle("Ratio gradient: " + r, fontsize=8)
    out_name = "ratio_grad_" + r.replace("/", "-") + ".png"
    f.savefig(os.path.join(fold_exp, out_name), bbox_inches='tight')
    plt.close("all")
tools.py 文件源码
python
阅读 17
收藏 0
点赞 0
评论 0
评论列表
文章目录