def plot_filter_heatmap(weights, filename=None):
    """Draw a filter weight matrix as a heatmap with a zero-centred colour scale.

    The colour limits are symmetric (``vmin=-m``, ``vmax=m`` where
    ``m = abs(weights).max()``), so positive and negative weights of equal
    magnitude are drawn with equally saturated colours of the diverging
    ``RdYlBu_r`` colormap.

    Args:
        weights: 2-D array-like exposing ``.shape`` and supporting
            ``abs(...)`` and ``.max()`` (e.g. a NumPy array). Rows become the
            y-axis cells and columns the x-axis cells.
        filename: If truthy, the figure is saved to this path and then
            closed. Otherwise the figure is left open for the caller to show
            or further customise.

    Returns:
        The matplotlib Axes holding the heatmap.
    """
    n_rows, n_cols = weights.shape[0], weights.shape[1]
    param_range = abs(weights).max()
    # Figure size in inches is one inch per cell along each axis.
    fig, ax = plt.subplots(figsize=(n_cols, n_rows))
    # Force one tick per column/row: with seaborn's default "auto" tick
    # labelling, ticks may be thinned, and the set_*ticklabels calls below
    # would then attach labels to the wrong cells.
    sns.heatmap(weights, cmap='RdYlBu_r', linewidths=0.2,
                vmin=-param_range, vmax=param_range,
                xticklabels=True, yticklabels=True, ax=ax)
    # Columns are labelled with 1-based positions 1..n_cols.
    ax.set_xticklabels(range(1, n_cols + 1))
    # NOTE(review): assumes ALPHABET_R (defined elsewhere in the project) has
    # at least n_rows entries. Labels are taken in reverse index order
    # (row 0 gets ALPHABET_R[n_rows - 1]); presumably this matches how
    # ALPHABET_R is ordered -- confirm against its definition.
    labels = [ALPHABET_R[i] for i in reversed(range(n_rows))]
    ax.set_yticklabels(labels, rotation='horizontal', size=10)
    if filename:
        # Save and close the figure created here, not whatever pyplot
        # currently considers the "current" figure.
        fig.savefig(filename)
        plt.close(fig)
    return ax
评论列表
文章目录