model.py 文件源码

python
阅读 15 收藏 0 点赞 0 评论 0

项目:tensorflow_ocr 作者: BowieHsu 项目源码 文件源码
def loss(y_true_pixel, y_pred_pixel,
         y_true_link, y_pred_link,
         training_mask):
    '''
    Build the scalar training loss: link loss + 2 * pixel classification loss.

    Pixels are selected by OHNM_batch (defined elsewhere; presumably online
    hard-negative mining -- confirm), and the resulting per-pixel selection
    weights are reused to weight every link loss term.

    Args:
        y_true_pixel: pixel labels; reshaped to [batch, -1] and cast to int32.
        y_pred_pixel: pixel logits; reshaped to [batch, -1, 2] (two classes).
        y_true_link: 4-D link labels; channel i of the last axis (i in 0..7)
            holds the labels of link i.
        y_pred_link: 4-D link logits; channels 2*i and 2*i+1 of the last axis
            are the two-class logits of link i.
        training_mask: accepted for interface compatibility; not used here.

    Returns:
        Scalar tensor equal to `weight_link_loss + 2 * classification_loss`.

    Side effects:
        Registers the 'classification_loss' and 'link_loss' scalar summaries.
    '''

    def _normalized_sum(weighted_loss, normalizer):
        # sum(weighted_loss) / normalizer, or 0 when the normalizer is not
        # positive, so that an empty selection contributes 0 instead of the
        # NaN produced by 0/0 (which would poison the whole loss).
        return tf.cond(normalizer > 0,
                       lambda: tf.reduce_sum(weighted_loss) / normalizer,
                       lambda: tf.constant(.0))

    pixel_shape = tf.shape(y_pred_pixel)
    batch_size = pixel_shape[0]
    pixel_label = tf.cast(tf.reshape(y_true_pixel, [batch_size, -1]), dtype=tf.int32)
    pixel_pred = tf.reshape(y_pred_pixel, [batch_size, -1, 2])

    # Softmax score of class 0, passed to the mining step as the "negative" score.
    pixel_scores = slim.softmax(pixel_pred)
    pixel_neg_scores = pixel_scores[:, :, 0]
    pixel_pos_mask, pixel_neg_mask = get_pos_and_neg_masks(pixel_label)

    # NOTE(review): 14 is hard-coded; presumably the batch size expected by
    # OHNM_batch -- confirm against its definition.
    pixel_selected_mask = OHNM_batch(14, pixel_neg_scores, pixel_pos_mask, pixel_neg_mask)
    n_seg_pos = tf.reduce_sum(tf.cast(pixel_pos_mask, tf.float32))

    # Pixel classification loss: selected-pixel cross-entropy normalized by the
    # number of positive pixels; 0 when the batch has no positive pixel.
    with tf.name_scope('ohem_pixel_loss'):
        pixel_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=pixel_pred, labels=pixel_label)
        classification_loss = _normalized_sum(pixel_loss * pixel_selected_mask, n_seg_pos)

    # Link loss: for each of the 8 links, positive and negative terms are
    # normalized separately by their own total weight and then added.
    total_link_loss = []
    with tf.name_scope('link_loss'):
        # Loop-invariant: flattened pixel selection weights shared by all links.
        W_pixel = tf.reshape(pixel_selected_mask, [-1])
        for i in range(8):
            y_link = y_true_link[:, :, :, i]
            pred_link = y_pred_link[:, :, :, 2 * i: 2 * (i + 1)]
            link_label = tf.cast(tf.reshape(y_link, [-1]), dtype=tf.int32)
            link_pred = tf.reshape(pred_link, [-1, 2])
            link_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=link_pred, labels=link_label)
            link_pos_mask, link_neg_mask = get_pos_and_neg_masks(link_label)

            W_link_pos = tf.cast(link_pos_mask, dtype=tf.float32) * W_pixel
            W_link_neg = tf.cast(link_neg_mask, dtype=tf.float32) * W_pixel
            link_pos_n = tf.reduce_sum(W_link_pos)
            link_neg_n = tf.reduce_sum(W_link_neg)

            # Guarded division: a link with no selected positives (or
            # negatives) contributes 0 for that term rather than NaN.
            link_pos_loss = _normalized_sum(link_loss * W_link_pos, link_pos_n)
            link_neg_loss = _normalized_sum(link_loss * W_link_neg, link_neg_n)
            total_link_loss.append(link_pos_loss + link_neg_loss)

    weight_link_loss = tf.reduce_sum(total_link_loss)

    tf.summary.scalar('classification_loss', classification_loss)
    tf.summary.scalar('link_loss', weight_link_loss)

    return weight_link_loss + 2 * classification_loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号