imgconvnets.py source code

python

Project: DmsMsgRcg  Author: bshao001
import tensorflow as tf  # TensorFlow 1.x API (tf.train.AdamOptimizer)

def _build_training_graph(self, logits, labels, learning_rate):
        """
        Build the training graph.
        Args:
            logits: Logits tensor, float - [batch_size, class_count].
            labels: Labels tensor, int32 - [batch_size], with values in the range
                [0, class_count).
            learning_rate: The learning rate for the optimization.
        Returns:
            train_op: The Op for training.
            loss: The Op for calculating loss.
            accuracy: The Op for calculating top-1 accuracy on the batch.
        """
        # Create an operation that calculates loss.
        labels = tf.cast(labels, tf.int64)  # tf.to_int64 is deprecated
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels, name='xentropy')
        loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)

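        # A prediction is correct when the true label is the top-1 logit;
        # accuracy is the fraction of correct predictions in the batch.
        correct_predict = tf.nn.in_top_k(logits, labels, 1)
        accuracy = tf.reduce_mean(tf.cast(correct_predict, tf.float32))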

        return train_op, loss, accuracy
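
To make the two metrics above concrete, here is a minimal NumPy sketch of what the loss and accuracy ops compute for one batch. It is an illustration, not the project's code: the helper names `sparse_softmax_cross_entropy` and `top1_accuracy` and the sample `logits`/`labels` values are assumptions made up for this example, and `argmax` may treat exact ties differently from `tf.nn.in_top_k`.

```python
import numpy as np


def sparse_softmax_cross_entropy(logits, labels):
    """Per-example cross-entropy for integer class labels (hypothetical helper)."""
    # Subtract the row-wise max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for each example.
    return -log_probs[np.arange(len(labels)), labels]


def top1_accuracy(logits, labels):
    """Fraction of examples whose highest logit is the true class (hypothetical helper)."""
    return float(np.mean(np.argmax(logits, axis=1) == labels))


# Made-up batch: 4 examples, 3 classes.
logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 2.5, 0.3],
                   [1.2, 0.2, 3.0],
                   [0.1, 0.2, 0.3]])
labels = np.array([0, 1, 2, 0])

loss = float(sparse_softmax_cross_entropy(logits, labels).mean())
accuracy = top1_accuracy(logits, labels)
print("loss:", loss, "accuracy:", accuracy)
```

The mean of the per-example values corresponds to `xentropy_mean` in the graph, and the accuracy corresponds to averaging the boolean output of `in_top_k` with `k=1`.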