def on_train_begin(self, logs=None):
    """Reset the tracked per-parameter statistics and open an interactive figure.

    For every layer returned by ``self.get_trainable_layers()`` and every
    parameter tag in ``self.parameters``, the entry
    ``self.layers_stats[layer.name + "_" + param]`` is reset when at least one
    of the layer's weights has ``param`` among the ``"_"``-separated tokens of
    its name. Resetting sets ``"values"`` to an empty 1-D numpy array and sets
    every key listed in ``self.stats`` to an empty list.

    Afterwards, matplotlib interactive mode is enabled. ``self.fig`` is set to
    a new figure whose width grows with ``len(self.stats)`` and whose height
    grows with the number of entries in ``self.layers_stats``.

    Args:
        logs: Unused. Accepted only to match the hook signature (presumably
            Keras ``Callback.on_train_begin`` -- confirm). The default is
            ``None`` rather than a shared mutable ``{}``.
    """
    for layer in self.get_trainable_layers():
        # Split each weight name once per layer instead of once per parameter.
        name_tokens = [w.name.split("_") for w in layer.weights]
        for param in self.parameters:
            # Test the name match itself. The previous
            # `any(w for w in ... if ...)` evaluated the truthiness of the
            # weight objects, which is not the intended check.
            if not any(param in tokens for tokens in name_tokens):
                continue
            name = layer.name + "_" + param
            # NOTE(review): assumes self.layers_stats[name] already exists
            # (e.g. layers_stats is a defaultdict) -- confirm with __init__.
            entry = self.layers_stats[name]
            entry["values"] = numpy.array([])  # empty float64 array, shape (0,)
            for s in self.stats:
                entry[s] = []

    plt.ion()  # interactive mode: figures draw without blocking
    # Width: 3 inches per entry of self.stats plus one extra 3-inch column.
    width = 3 * (1 + len(self.stats))
    # Height: 2 inches per key in self.layers_stats. This counts every key,
    # not only those reset above.
    height = 2 * len(self.layers_stats)
    self.fig = plt.figure(figsize=(width, height))  # figsize is in inches
评论列表
文章目录