BaseModel.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:kaggle-review 作者: daxiongshu 项目源码 文件源码
def _get_loss(self, labels):
        """Build the loss, accuracy and (optionally) summary ops for the graph.

        Subclasses may override this to use a different training objective.

        Sets:
            self.loss: mean softmax cross-entropy between ``self.logit`` and
                ``labels`` (``labels`` is cast to float32 first).
            self.acc: fraction of rows for which the argmax over axis 1 of
                ``self.logit`` equals the argmax over axis 1 of ``labels``.

        Also calls ``self._get_l2_loss()``. ``self.l2loss`` is read further down
        when summaries are enabled, so that helper presumably sets it
        (TODO confirm).

        When ``self.flags.visualize`` is truthy, scalar summaries for the loss,
        accuracy and L2 loss are added to ``tf.GraphKeys.SCALARS``. If ``'acc'``
        is also contained in ``self.flags.visualize``, histograms of predicted
        and true labels plus one per-class logit histogram (for each of
        ``self.flags.classes`` classes) are added to
        ``tf.GraphKeys.FEATURE_MAPS``.

        Args:
            labels: target tensor. Presumably one-hot / class-probability rows
                with the same shape as ``self.logit`` (batch, classes);
                TODO confirm against callers.
        """
        with tf.name_scope("Loss"):
            with tf.name_scope("cross_entropy"):
                labels = tf.cast(labels, tf.float32)
                self.loss = tf.reduce_mean(
                    tf.nn.softmax_cross_entropy_with_logits(
                        logits=self.logit, labels=labels))
            # Presumably populates self.l2loss, which the summaries below read.
            self._get_l2_loss()
            with tf.name_scope("accuracy"):
                y_label = tf.argmax(labels, 1)
                yp_label = tf.argmax(self.logit, 1)
                correct_pred = tf.equal(yp_label, y_label)
                self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

        with tf.name_scope("summary"):
            if self.flags.visualize:
                tf.summary.scalar(name='TRAIN_CrossEntropy', tensor=self.loss,
                                  collections=[tf.GraphKeys.SCALARS])
                tf.summary.scalar(name='TRAIN_Accuracy', tensor=self.acc,
                                  collections=[tf.GraphKeys.SCALARS])
                tf.summary.scalar(name='TRAIN_L2loss', tensor=self.l2loss,
                                  collections=[tf.GraphKeys.SCALARS])

                if 'acc' in self.flags.visualize:
                    tf.summary.histogram(name='pred', values=yp_label,
                                         collections=[tf.GraphKeys.FEATURE_MAPS])
                    tf.summary.histogram(name='truth', values=y_label,
                                         collections=[tf.GraphKeys.FEATURE_MAPS])
                    for cl in range(self.flags.classes):
                        # A size of -1 takes the whole (actual) batch dimension,
                        # so this works for any batch size, not only
                        # flags.batch_size (e.g. a smaller final batch).
                        tf.summary.histogram(
                            name='pred%d' % cl,
                            values=tf.slice(self.logit, [0, cl], [-1, 1]),
                            collections=[tf.GraphKeys.FEATURE_MAPS])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号