4-mnist-using-contrib.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:npfl114 作者: ufal 项目源码 文件源码
def __init__(self, logdir, experiment, threads):
        """Build the classifier graph, its summaries, the session and the summary writer.

        Args:
            logdir: Directory under which this run's summary directory is created.
            experiment: Experiment name appended to the timestamped run directory name.
            threads: Thread count used for both inter-op and intra-op parallelism
                of the TensorFlow session.
        """
        # Construct the graph
        with tf.name_scope("inputs"):
            # Batch of single-channel images and one integer class id per image.
            self.images = tf.placeholder(tf.float32, [None, WIDTH, HEIGHT, 1], name="images")
            self.labels = tf.placeholder(tf.int64, [None], name="labels")
            flattened_images = layers.flatten(self.images)

        # One ReLU hidden layer followed by a linear output layer (no activation),
        # whose outputs are fed to the softmax cross-entropy loss as logits.
        hidden_layer = layers.fully_connected(flattened_images, num_outputs=HIDDEN, activation_fn=tf.nn.relu, scope="hidden_layer")
        output_layer = layers.fully_connected(hidden_layer, num_outputs=LABELS, activation_fn=None, scope="output_layer")

        loss = losses.sparse_softmax_cross_entropy(output_layer, self.labels, scope="loss")
        # Training op using Adam; the loss/gradient summaries requested here are
        # part of the 'training' summary set merged below.
        self.training = layers.optimize_loss(loss, None, None, tf.train.AdamOptimizer(), summaries=['loss', 'gradients', 'gradient_norm'], name='training')

        with tf.name_scope("accuracy"):
            predictions = tf.argmax(output_layer, 1, name="predictions")
            accuracy = metrics.accuracy(predictions, self.labels)
            tf.scalar_summary("training/accuracy", accuracy)

        with tf.name_scope("confusion_matrix"):
            # Weighting by `predictions != labels` makes only misclassified
            # examples contribute, so the image shows errors only.
            # num_classes is given explicitly: otherwise the matrix size is
            # inferred from the largest class id seen in the batch, and the
            # reshape below fails for batches that do not contain class LABELS - 1.
            confusion_matrix = metrics.confusion_matrix(predictions, self.labels,
                                                        num_classes=LABELS,
                                                        weights=tf.not_equal(predictions, self.labels),
                                                        dtype=tf.float32)
            # Image summaries expect a [batch, height, width, channels] tensor.
            confusion_image = tf.reshape(confusion_matrix, [1, LABELS, LABELS, 1])

        # Summaries
        # merge_all_summaries() must run before the dev/test summaries are created,
        # so that 'training' contains only the summaries defined above.
        self.summaries = {'training': tf.merge_all_summaries() }
        for dataset in ["dev", "test"]:
            self.summaries[dataset] = tf.merge_summary([tf.scalar_summary(dataset + "/accuracy", accuracy),
                                                        tf.image_summary(dataset + "/confusion_matrix", confusion_image)])

        # Create the session
        self.session = tf.Session(config=tf.ConfigProto(inter_op_parallelism_threads=threads,
                                                        intra_op_parallelism_threads=threads))

        self.session.run(tf.initialize_all_variables())
        # Each run logs to its own "<logdir>/<timestamp>-<experiment>" directory.
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")
        self.summary_writer = tf.train.SummaryWriter("{}/{}-{}".format(logdir, timestamp, experiment), graph=self.session.graph, flush_secs=10)
        # Step counter, initialised to zero here; it is not advanced in this method.
        self.steps = 0
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号