def weight_norm_histogram(rbm, show_plot=False, filename=None):
    """Plot the distribution of per-column weight norms for each RBM layer.

    For every weight layer, ``be.norm(W, axis=0)`` is computed (one value per
    column of the layer's weight matrix) and divided by ``sqrt(num_inputs)``,
    where ``num_inputs`` is ``weights.shape[0]``. Each layer's values are drawn
    as a density plot on a single shared axis, labelled by layer index.

    Args:
        rbm: model exposing ``num_weights`` and an indexable ``weights``
            collection whose entries provide ``shape`` and a ``W()`` method.
        show_plot (bool): if True, display the figure via ``plt.show()``.
        filename (str, optional): if not None, save the figure to this path.

    Returns:
        None. The figure is always closed before returning, even on error.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    fig, ax = plt.subplots()
    try:
        for layer in range(rbm.num_weights):
            weights = rbm.weights[layer]
            num_inputs = weights.shape[0]
            # NOTE(review): presumably normalised by sqrt(fan-in) so layers
            # with different input sizes are comparable -- confirm intent.
            norm = be.to_numpy_array(
                be.norm(weights.W(), axis=0) / sqrt(num_inputs))
            if hasattr(sns, "histplot"):
                # distplot is deprecated since seaborn 0.11; histplot with a
                # KDE on a density scale is the documented replacement.
                sns.histplot(norm, ax=ax, label=str(layer),
                             kde=True, stat="density")
            else:
                sns.distplot(norm, ax=ax, label=str(layer))
        ax.legend()
        if show_plot:
            # pyplot.show() takes no figure argument; passing one raises
            # TypeError on current matplotlib backends.
            plt.show()
        if filename is not None:
            fig.savefig(filename)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
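# 评论列表 ("comment list") -- web-page navigation residue, not code
# 文章目录 ("table of contents") -- web-page navigation residue, not code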