def plotKLdivergenceHeatmap(KL_all, name=None):
    """Plot a heatmap of log KL divergences (activity of latent units).

    ``KL_all`` is transposed before plotting, so each row of ``KL_all``
    becomes a column of the heatmap (x-axis labelled "Epoch"). Each column
    of ``KL_all`` becomes a heatmap row; presumably this is one row per
    latent unit — confirm against callers.

    Entries that are not strictly positive (zero, negative or NaN) have no
    finite logarithm. They are masked (drawn blank) instead of being passed
    to ``log``. A single ``-inf`` cell would otherwise become the
    colour-scale minimum and break the colour normalisation of the whole
    plot.

    Args:
        KL_all: Array-like of KL values accepted by ``array``. It must be
            2-D for ``seaborn.heatmap``.
        name: Optional prefix. When truthy, the figure name becomes
            ``name + "/KL_divergence_heatmap"`` before being passed to
            ``data.saveFigure``.
    """
    print("Plotting KL-divergence heatmap (activity of latent units).")

    figure_name = "KL_divergence_heatmap"
    if name:
        figure_name = name + "/" + figure_name

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)

    KL_array = array(KL_all)
    print("Dimensions of KL-activations:")
    print(KL_array.shape)

    KL_transposed = KL_array.T
    # `~(x > 0)` is True for zeros, negatives and NaNs alike.
    invalid = ~(KL_transposed > 0)
    # astype(float) returns a copy, so KL_array itself is not modified.
    # Invalid cells are set to 1 (log 1 == 0) so that `log` produces no
    # -inf/NaN or RuntimeWarning; those cells are hidden via `mask` below.
    positive_KL = KL_transposed.astype(float)
    positive_KL[invalid] = 1.0
    log_KL = log(positive_KL)

    seaborn.heatmap(log_KL, mask=invalid, xticklabels=True, yticklabels=False,
                    cbar=True, center=None, square=True, ax=axis)
    axis.set_xlabel("Epoch")
    # `\log` renders as an upright operator in mathtext; a bare `log` is
    # typeset as the italic product of the variables l, o and g.
    axis.set_ylabel(r"$\log KL(p_i||q_i)$")

    data.saveFigure(figure, figure_name, no_spine=False)
analysis.py 文件源码
python
阅读 29
收藏 0
点赞 0
评论 0
评论列表
文章目录