def set_model(self, model):
    """Attach ``model`` to the callback and build its TensorBoard summaries.

    Stores ``model`` on ``self.model`` and the backend session on
    ``self.sess``.  When ``self.histogram_freq`` is truthy and no merged
    summary op exists yet (``self.merged is None``), a histogram summary is
    registered for every weight of every layer in ``model.layers`` (plus an
    image summary when ``self.write_images`` is set) and for every layer
    that has an ``output``.  All summaries in the graph are then merged into
    ``self.merged``, and a summary writer for ``self.log_dir`` is stored on
    ``self.writer``; the session graph is handed to the writer only when
    ``self.write_graph`` is set.

    Both the pre-1.0 TensorFlow summary API (``tf.histogram_summary``,
    ``tf.image_summary``, ``tf.merge_all_summaries``,
    ``tf.train.SummaryWriter``) and the ``tf.summary`` API are supported;
    whichever exists is detected with ``hasattr``.

    Only weights whose squeezed rank is 1 or 2 are written as images.
    Other ranks (scalars, convolution kernels, ...) cannot be turned into
    the 4-D ``[batch, height, width, channels]`` tensor image summaries
    require, so they are skipped instead of failing graph construction.

    # Arguments
        model: model whose ``layers`` (their ``weights`` and ``output``)
            are summarised.
    """

    def _histogram(name, values):
        # Pre-1.0 TensorFlow names this tf.histogram_summary; later
        # versions moved it to tf.summary.histogram.
        if hasattr(tf, 'histogram_summary'):
            tf.histogram_summary(name, values)
        else:
            tf.summary.histogram(name, values)

    def _image(name, tensor):
        # Same API split as above: tf.image_summary vs tf.summary.image.
        if hasattr(tf, 'image_summary'):
            tf.image_summary(name, tensor)
        else:
            tf.summary.image(name, tensor)

    def _weight_as_image(weight):
        """Return ``weight`` as a [1, h, w, 1] image tensor, or None.

        None is returned when the squeezed weight is neither a vector nor a
        matrix, because no single 4-D image can be built from it.
        """
        w_img = tf.squeeze(weight)
        shape = w_img.get_shape()
        if len(shape) == 2:
            # Lay tall matrices out landscape so they are easier to view.
            if shape[0] > shape[1]:
                w_img = tf.transpose(w_img)
        elif len(shape) == 1:
            # Vector (e.g. a bias) -> 1 x n matrix.
            w_img = tf.expand_dims(w_img, 0)
        else:
            # Scalars would become 2-D and rank > 2 kernels 5-D/6-D after
            # the expansion below; image summaries reject both.
            return None
        # [h, w] -> [1, h, w, 1]: a batch of one single-channel image.
        return tf.expand_dims(tf.expand_dims(w_img, 0), -1)

    self.model = model
    self.sess = K.get_session()

    if self.histogram_freq and self.merged is None:
        for layer in self.model.layers:
            for weight in layer.weights:
                _histogram(weight.name, weight)
                if self.write_images:
                    w_img = _weight_as_image(weight)
                    if w_img is not None:
                        _image(weight.name, w_img)
            if hasattr(layer, 'output'):
                _histogram('{}_out'.format(layer.name), layer.output)

    if hasattr(tf, 'merge_all_summaries'):
        self.merged = tf.merge_all_summaries()
    else:
        self.merged = tf.summary.merge_all()

    has_file_writer = (hasattr(tf, 'summary') and
                       hasattr(tf.summary, 'FileWriter'))
    if self.write_graph:
        if has_file_writer:
            self.writer = tf.summary.FileWriter(self.log_dir,
                                                self.sess.graph)
        elif parse_version(tf.__version__) >= parse_version('0.8.0'):
            self.writer = tf.train.SummaryWriter(self.log_dir,
                                                 self.sess.graph)
        else:
            # Releases older than 0.8.0 are given the GraphDef rather
            # than the Graph object.
            self.writer = tf.train.SummaryWriter(self.log_dir,
                                                 self.sess.graph_def)
    elif has_file_writer:
        self.writer = tf.summary.FileWriter(self.log_dir)
    else:
        self.writer = tf.train.SummaryWriter(self.log_dir)
callbacks.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录