def __init__(self, model, log_dir, histogram_freq=0, image_freq=0,
             audio_freq=0, write_graph=False):
    """Set up TensorBoard summaries and a summary writer for ``model``.

    Args:
        model: Keras model whose ``layers``, ``input`` and ``output`` are
            used to register summaries.
        log_dir: Directory handed to the TensorFlow summary writer.
        histogram_freq: When non-zero, histogram summaries are registered
            for each layer's ``W``, ``b`` and ``output`` (where present).
            Presumably also the emission interval -- only the non-zero
            test is visible here; confirm against the batch hooks.
        image_freq: When non-zero, image summaries are registered for the
            model input and output.
        audio_freq: When non-zero, audio summaries are registered for the
            model input and output.
        write_graph: When true, the session graph is passed to the summary
            writer so it is written alongside the summaries.

    Raises:
        Exception: If the active Keras backend is not TensorFlow.
    """
    # NOTE(review): super(Callback, self) starts the MRO lookup *after*
    # Callback, so Callback.__init__ itself is skipped. Left unchanged
    # because the enclosing class name is not visible here -- confirm.
    super(Callback, self).__init__()
    if K._BACKEND != 'tensorflow':
        raise Exception('TensorBoardBatch callback only works '
                        'with the TensorFlow backend.')

    # Imported lazily so the module can be loaded without TensorFlow.
    import re
    import tensorflow as tf
    import keras.backend.tensorflow_backend as KTF
    self.tf = tf
    self.KTF = KTF

    self.log_dir = log_dir
    self.histogram_freq = histogram_freq
    self.image_freq = image_freq
    self.audio_freq = audio_freq
    self.write_graph = write_graph

    self.histograms = None
    self.iter = 0
    self.scalars = []
    self.images = []
    self.audios = []

    self.model = model
    self.sess = KTF.get_session()

    if self.histogram_freq != 0:
        for layer in self.model.layers:
            # Fall back to the layer object itself when it has no name.
            layer_name = layer.name if hasattr(layer, 'name') else layer
            for attr, suffix in (('W', 'W'), ('b', 'b'), ('output', 'out')):
                if hasattr(layer, attr):
                    name = '{}_{}'.format(layer_name, suffix)
                    tf.histogram_summary(name, getattr(layer, attr),
                                         collections=['histograms'])

    if self.image_freq != 0:
        tf.image_summary('input', self.model.input, max_images=2,
                         collections=['images'])
        tf.image_summary('output', self.model.output, max_images=2,
                         collections=['images'])

    if self.audio_freq != 0:
        tf.audio_summary('input', self.model.input, max_outputs=1,
                         collections=['audios'])
        tf.audio_summary('output', self.model.output, max_outputs=1,
                         collections=['audios'])

    if self.write_graph:
        # Compare versions numerically: as strings, '0.10.0' sorts before
        # '0.8.0' and would wrongly select the legacy graph_def branch.
        version = tuple(
            int(part) for part in re.findall(r'\d+', self.tf.__version__)[:2]
        )
        if version >= (0, 8):
            self.writer = self.tf.train.SummaryWriter(self.log_dir,
                                                      self.sess.graph)
        else:
            self.writer = self.tf.train.SummaryWriter(self.log_dir,
                                                      self.sess.graph_def)
    else:
        self.writer = self.tf.train.SummaryWriter(self.log_dir)
评论列表
文章目录