experiment.py source code

python

Project: Graph-CNN    Author: fps7806
def create_loss_function(self):
        with tf.variable_scope('loss') as scope:
            self.print_ext('Creating loss function and summaries')
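            # sparse_softmax_cross_entropy_with_logits takes integer class labels and raw (pre-softmax) logits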
            cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.net.current_V, labels=self.net.labels))

            correct_prediction = tf.cast(tf.equal(tf.argmax(self.net.current_V, 1), self.net.labels), tf.float32)
            accuracy = tf.reduce_mean(correct_prediction)

            # two variables keep track of the best accuracy obtained on a training/test batch
            # SHOULD ONLY BE USED IF test_batch_size == ALL TEST SAMPLES
            self.max_acc_train = tf.Variable(tf.zeros([]), name="max_acc_train")
            self.max_acc_test = tf.Variable(tf.zeros([]), name="max_acc_test")
            max_acc = tf.cond(
                self.net.is_training,
                lambda: tf.assign(self.max_acc_train, tf.maximum(self.max_acc_train, accuracy)),
                lambda: tf.assign(self.max_acc_test, tf.maximum(self.max_acc_test, accuracy)))

            tf.add_to_collection('losses', cross_entropy)
            tf.summary.scalar('accuracy', accuracy)
            tf.summary.scalar('max_accuracy', max_acc)
            tf.summary.scalar('cross_entropy', cross_entropy)

            # if silent == False, display these statistics:
            self.reports['accuracy'] = accuracy
            self.reports['max acc.'] = max_acc
            self.reports['cross_entropy'] = cross_entropy

    # check if the model has a saved iteration and return the latest iteration step
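
The single tf.cond call above folds the best-accuracy bookkeeping into one graph op: in training mode it updates max_acc_train, otherwise max_acc_test, and it returns the updated value so it can be logged as a summary. Below is a minimal standalone sketch of that pattern, assuming the TensorFlow 1.x graph-mode API used in the snippet; the placeholder names is_training and accuracy are illustrative stand-ins for the corresponding fields of the experiment class, not part of the original code.

import tensorflow as tf

# illustrative placeholders standing in for self.net.is_training and the accuracy tensor
is_training = tf.placeholder(tf.bool, shape=[], name='is_training')
accuracy = tf.placeholder(tf.float32, shape=[], name='accuracy')

# running maxima, mirroring max_acc_train / max_acc_test above
max_acc_train = tf.Variable(0.0, name='max_acc_train')
max_acc_test = tf.Variable(0.0, name='max_acc_test')

# tf.cond executes only the taken branch; tf.assign writes the new maximum back into the variable
max_acc = tf.cond(
    is_training,
    lambda: tf.assign(max_acc_train, tf.maximum(max_acc_train, accuracy)),
    lambda: tf.assign(max_acc_test, tf.maximum(max_acc_test, accuracy)))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(max_acc, {is_training: True, accuracy: 0.7}))   # 0.7
    print(sess.run(max_acc, {is_training: True, accuracy: 0.5}))   # still 0.7
    print(sess.run(max_acc, {is_training: False, accuracy: 0.6}))  # 0.6 (test maximum)

Because each maximum is taken over a single evaluation batch, the comment in the original code is right to warn that max_acc_test is only meaningful when one test batch covers the entire test set.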