callbacks.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:neural_style 作者: metaflow-ai 项目源码 文件源码
def __init__(self, model, log_dir, histogram_freq=0, image_freq=0, audio_freq=0, write_graph=False,
             audio_sample_rate=44100):
        """Create TensorFlow summary ops for `model` and a SummaryWriter for `log_dir`.

        Summary ops are registered in the 'histograms', 'images' and 'audios'
        graph collections, depending on which frequencies are non-zero.

        Args:
            model: Keras model whose layers / input / output are summarized.
            log_dir: directory passed to `tf.train.SummaryWriter`.
            histogram_freq: if non-zero, add a histogram summary for each
                layer's `W`, `b` and `output` attribute that exists.
            image_freq: if non-zero, add image summaries (at most 2 images)
                of the model input and output.
            audio_freq: if non-zero, add audio summaries (at most 1 clip)
                of the model input and output.
            write_graph: if True, the writer is also given the session graph.
            audio_sample_rate: sample rate (Hz) passed to `tf.audio_summary`;
                only used when `audio_freq` is non-zero. The 44100 default is
                an assumption -- set it to match the actual audio data.

        Raises:
            Exception: if the Keras backend is not TensorFlow.
        """
        # Call Callback.__init__ explicitly: `super(Callback, self)` would
        # begin the MRO lookup *after* Callback and skip its initializer.
        Callback.__init__(self)
        if K._BACKEND != 'tensorflow':
            raise Exception('TensorBoardBatch callback only works '
                        'with the TensorFlow backend.')
        import re
        import tensorflow as tf
        self.tf = tf
        import keras.backend.tensorflow_backend as KTF
        self.KTF = KTF

        self.log_dir = log_dir
        self.histogram_freq = histogram_freq
        self.image_freq = image_freq
        self.audio_freq = audio_freq
        self.audio_sample_rate = audio_sample_rate
        self.histograms = None
        self.write_graph = write_graph
        self.iter = 0
        self.scalars = []
        self.images = []
        self.audios = []

        self.model = model
        self.sess = KTF.get_session()

        if self.histogram_freq != 0:
            # (layer attribute, tag suffix) pairs, summarized in this order
            # whenever the layer has the attribute.
            tracked = (('W', 'W'), ('b', 'b'), ('output', 'out'))
            for layer in self.model.layers:
                # Fall back to the layer object itself when it has no name.
                layer_name = getattr(layer, 'name', layer)
                for attr, suffix in tracked:
                    if hasattr(layer, attr):
                        tf.histogram_summary('{}_{}'.format(layer_name, suffix),
                                             getattr(layer, attr),
                                             collections=['histograms'])

        if self.image_freq != 0:
            tf.image_summary('input', self.model.input, max_images=2, collections=['images'])
            tf.image_summary('output', self.model.output, max_images=2, collections=['images'])

        if self.audio_freq != 0:
            # tf.audio_summary's third positional argument (sample_rate) is required.
            tf.audio_summary('input', self.model.input, audio_sample_rate,
                             max_outputs=1, collections=['audios'])
            tf.audio_summary('output', self.model.output, audio_sample_rate,
                             max_outputs=1, collections=['audios'])

        if not self.write_graph:
            self.writer = tf.train.SummaryWriter(self.log_dir)
        else:
            # Compare versions numerically: as plain strings '0.10.0' >= '0.8.0'
            # is False, which would send TF 0.10-0.12 down the legacy branch.
            tf_version = tuple(int(part) for part in re.findall(r'\d+', tf.__version__)[:2])
            if tf_version >= (0, 8):
                # TF >= 0.8: pass the Graph object.
                self.writer = tf.train.SummaryWriter(self.log_dir, self.sess.graph)
            else:
                # Older TF: pass the serialized GraphDef.
                self.writer = tf.train.SummaryWriter(self.log_dir, self.sess.graph_def)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号