def get_loss(pred, label, end_points, reg_weight=0.001):
    """Build the training loss: classification term plus transform regularizer.

    Args:
        pred: logits, BxNxC.
        label: integer class ids, BxN.
        end_points: mapping that holds the learned transform under the
            'transform' key (BxKxK).
        reg_weight: multiplier applied to the orthogonality penalty.

    Returns:
        Scalar tensor equal to the mean cross-entropy plus
        ``reg_weight`` times the orthogonality penalty.

    Two scalar summaries ('classify loss' and 'mat_loss') are registered
    as a side effect.
    """
    # Cross-entropy per point, then averaged over all elements.
    per_point_ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=pred, labels=label)
    classify_loss = tf.reduce_mean(per_point_ce)
    tf.scalar_summary('classify loss', classify_loss)

    # Push the transform towards an orthogonal matrix by penalizing the
    # distance between A * A^T and the identity.
    feat_transform = end_points['transform']  # BxKxK
    dim = feat_transform.get_shape()[1].value
    transposed = tf.transpose(feat_transform, perm=[0, 2, 1])
    gram = tf.matmul(feat_transform, transposed)
    identity = tf.constant(np.eye(dim), dtype=tf.float32)
    residual = gram - identity  # KxK identity is subtracted from each batch item
    mat_diff_loss = tf.nn.l2_loss(residual)
    tf.scalar_summary('mat_loss', mat_diff_loss)

    weighted_penalty = mat_diff_loss * reg_weight
    return classify_loss + weighted_penalty
评论列表
文章目录