callbacks.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:deep-learning-keras-projects 作者: jasmeetsb 项目源码 文件源码
def set_model(self, model):
    """Attach ``model`` to the callback and build its TensorBoard summaries.

    Stores ``model`` on ``self.model`` and the backend session on
    ``self.sess``.

    When ``self.histogram_freq`` is truthy and ``self.merged`` is still
    ``None``, it registers:

    * a histogram summary for every layer weight;
    * a histogram summary named ``<layer name>_out`` for every layer that
      has an ``output`` attribute;
    * if ``self.write_images`` is set, an image summary for each weight
      that can be laid out as a 2-D picture (vectors and matrices).
      Scalars and weights of rank > 2 (e.g. convolution kernels) are
      skipped, because image summaries only accept 4-D tensors.

    Afterwards, all summaries are merged into ``self.merged``.
    ``self.writer`` is created for ``self.log_dir``, including the
    session graph when ``self.write_graph`` is set.

    Both the legacy TensorFlow summary API (``tf.histogram_summary``,
    ``tf.merge_all_summaries``, ``tf.train.SummaryWriter``) and the
    ``tf.summary`` module are supported.

    # Arguments
        model: the model being trained; its ``layers`` are inspected.
    """
    self.model = model
    self.sess = K.get_session()

    # NOTE(review): assumes ``self.merged`` is initialised (e.g. to None)
    # elsewhere, typically in ``__init__`` -- not visible here.
    if self.histogram_freq and self.merged is None:
        # Resolve the summary API once instead of per weight / per layer.
        if hasattr(tf, 'histogram_summary'):
            histogram_summary = tf.histogram_summary
        else:
            histogram_summary = tf.summary.histogram

        for layer in self.model.layers:
            for weight in layer.weights:
                histogram_summary(weight.name, weight)

                if not self.write_images:
                    continue

                w_img = tf.squeeze(weight)
                shape = w_img.get_shape()
                if len(shape) == 1:
                    # Vector (e.g. a bias): show it as a single-row image.
                    w_img = tf.expand_dims(w_img, 0)
                elif len(shape) == 2:
                    # Matrix: lay it out with the longer axis horizontal.
                    if shape[0] > shape[1]:
                        w_img = tf.transpose(w_img)
                else:
                    # Image summaries need a 4-D [batch, height, width,
                    # channels] tensor. A scalar or a rank > 2 kernel cannot
                    # become one via the two expand_dims below, and passing
                    # it would fail at graph construction, so skip it.
                    continue

                # [h, w] -> [1, h, w, 1]: one single-channel image.
                w_img = tf.expand_dims(tf.expand_dims(w_img, 0), -1)

                if hasattr(tf, 'image_summary'):
                    tf.image_summary(weight.name, w_img)
                else:
                    tf.summary.image(weight.name, w_img)

            if hasattr(layer, 'output'):
                histogram_summary('{}_out'.format(layer.name), layer.output)

    if hasattr(tf, 'merge_all_summaries'):
        self.merged = tf.merge_all_summaries()
    else:
        self.merged = tf.summary.merge_all()

    has_file_writer = (hasattr(tf, 'summary') and
                       hasattr(tf.summary, 'FileWriter'))
    if self.write_graph:
        if has_file_writer:
            self.writer = tf.summary.FileWriter(self.log_dir,
                                                self.sess.graph)
        elif parse_version(tf.__version__) >= parse_version('0.8.0'):
            self.writer = tf.train.SummaryWriter(self.log_dir,
                                                 self.sess.graph)
        else:
            # TensorFlow releases older than 0.8.0 are given the GraphDef
            # instead of the Graph object.
            self.writer = tf.train.SummaryWriter(self.log_dir,
                                                 self.sess.graph_def)
    elif has_file_writer:
        self.writer = tf.summary.FileWriter(self.log_dir)
    else:
        self.writer = tf.train.SummaryWriter(self.log_dir)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号