mfb_dis_net.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:unsupervised-2017-cvprw 作者: imatge-upc 项目源码 文件源码
def tower_loss(name_scope, mfb, use_pretrained_encoder, encoder_gradient_ratio=1.0):
    # get reconstruction and ground truth
    ac_loss = mfb.ac_loss

    weight_decay_loss_list = tf.get_collection('losses', name_scope)
    if use_pretrained_encoder:
        if encoder_gradient_ratio == 0.0:
            weight_decay_loss_list = [var for var in weight_decay_loss_list \
                                      if 'c3d' not in var.name and 'mapping' not in var.name]

    weight_decay_loss = 0.0
    if len(weight_decay_loss_list) > 0:
        weight_decay_loss = tf.add_n(weight_decay_loss_list)

    total_loss = weight_decay_loss * 100 + ac_loss

    return total_loss, ac_loss, weight_decay_loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号