def _get_loss(self, labels):
    """Build the training loss, accuracy, and summary ops.

    This method may be overridden by subclasses.

    Args:
        labels: label tensor; cast to float32 here. Assumed to be one-hot
            with shape (batch, classes) — consistent with the
            ``tf.argmax(labels, 1)`` below, but TODO confirm against callers.

    Side effects:
        Sets ``self.loss`` (mean softmax cross-entropy) and ``self.acc``
        (mean top-1 accuracy); calls ``self._get_l2_loss()``, which
        presumably sets ``self.l2loss`` (read by the summary below).
        When ``self.flags.visualize`` is truthy, registers scalar and
        histogram summaries.
    """
    with tf.name_scope("Loss"):
        with tf.name_scope("cross_entropy"):
            # Cross-entropy expects float labels; logits come from self.logit.
            labels = tf.cast(labels, tf.float32)
            self.loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(
                    logits=self.logit, labels=labels))
        self._get_l2_loss()
        with tf.name_scope("accuracy"):
            # Top-1 accuracy: compare argmax of predictions vs. labels.
            y_label = tf.argmax(labels, 1)
            yp_label = tf.argmax(self.logit, 1)
            correct_pred = tf.equal(yp_label, y_label)
            self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
        with tf.name_scope("summary"):
            # NOTE(review): tf.GraphKeys.SCALARS / FEATURE_MAPS are not
            # standard TF collection keys — presumably added elsewhere in
            # this project; verify they exist before graph construction.
            if self.flags.visualize:
                tf.summary.scalar(name='TRAIN_CrossEntropy',
                                  tensor=self.loss,
                                  collections=[tf.GraphKeys.SCALARS])
                tf.summary.scalar(name='TRAIN_Accuracy',
                                  tensor=self.acc,
                                  collections=[tf.GraphKeys.SCALARS])
                tf.summary.scalar(name='TRAIN_L2loss',
                                  tensor=self.l2loss,
                                  collections=[tf.GraphKeys.SCALARS])
            # flags.visualize is treated as a container of feature names
            # here (membership test) — e.g. a string or list; TODO confirm.
            if 'acc' in self.flags.visualize:
                tf.summary.histogram(name='pred', values=yp_label,
                                     collections=[tf.GraphKeys.FEATURE_MAPS])
                tf.summary.histogram(name='truth', values=y_label,
                                     collections=[tf.GraphKeys.FEATURE_MAPS])
                # One histogram of raw logits per class column.
                for cl in range(self.flags.classes):
                    tf.summary.histogram(
                        name='pred%d' % cl,
                        values=tf.slice(self.logit, [0, cl],
                                        [self.flags.batch_size, 1]),
                        collections=[tf.GraphKeys.FEATURE_MAPS])