base.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def _loss_softmax(self, logits, labels, is_training, weighted=False):
        """Build the softmax cross-entropy loss, plus scaled regularization when training.

        Args:
            logits: unnormalized class scores; presumably shaped
                (batch, num_classes) -- TODO confirm against callers.
            labels: integer class ids (any static rank other than 2, which are
                one-hot encoded to ``self.num_classes``), or an already
                one-hot / soft-label tensor of static rank 2.
            is_training: when truthy, the mean cross-entropy and the scaled
                regularization loss are added to the ``'losses'`` collection
                and the sum of that collection is returned.
            weighted: when True, per-example weights derived from
                ``self._compute_weights`` are applied and
                ``self.label_smoothing`` is used.

        Returns:
            The ``total_loss`` tensor when training, otherwise the mean
            cross-entropy tensor.
        """
        log.info('Using softmax loss')
        logits = tf.convert_to_tensor(logits)
        labels = tf.convert_to_tensor(labels)
        # Use the *static* rank: `tf.rank(...) != 2` compares a Tensor object to
        # an int, which in TF1 graph mode is always True. An unknown static rank
        # (None) still falls through to one-hot encoding, as before.
        if labels.get_shape().ndims != 2:
            labels = tf.one_hot(tf.cast(labels, tf.int64), self.num_classes)
        # Cross-entropy ops require labels in the same float dtype as logits;
        # casting here (not to int64) keeps rank-2 soft labels intact.
        labels = tf.cast(labels, logits.dtype)
        if weighted:
            weights = self._compute_weights(labels)
            # Reduce to one weight per example: the weight of its labelled class.
            weights = tf.reduce_max(tf.multiply(weights, labels), axis=1)
            # loss_collection=None: this loss is added to 'losses' explicitly
            # below; letting tf.losses also register it would count it twice.
            ce_loss = tf.losses.softmax_cross_entropy(
                labels, logits=logits, weights=weights, label_smoothing=self.label_smoothing,
                scope='cross_entropy_loss', loss_collection=None)
        else:
            # NOTE(review): self.label_smoothing is not applied on this path.
            # Confirm that this asymmetry with the weighted branch is intended.
            ce_loss = tf.nn.softmax_cross_entropy_with_logits(
                labels=labels, logits=logits, name='cross_entropy_loss')
        ce_loss_mean = tf.reduce_mean(ce_loss, name='cross_entropy')
        if not is_training:
            return ce_loss_mean

        tf.add_to_collection('losses', ce_loss_mean)

        # tf.add_n rejects an empty list, so only add an L2 term when the
        # graph actually registered regularization losses.
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if reg_losses:
            l2_loss = tf.add_n(reg_losses) * self.cnf.get('l2_reg', 0.0)
            tf.add_to_collection('losses', l2_loss)

        return tf.add_n(tf.get_collection('losses'), name='total_loss')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号