def tower_loss(name_scope, mfb, use_pretrained_encoder, encoder_gradient_ratio=1.0,
               weight_decay_scale=100):
    """Build the total loss for one tower.

    The total is ``weight_decay_loss * weight_decay_scale + ac_loss``, where
    ``ac_loss`` is read from ``mfb`` and ``weight_decay_loss`` is the sum of
    the ``'losses'`` collection entries under ``name_scope``.

    Args:
        name_scope: Scope filter passed to ``tf.get_collection``; only entries
            of the ``'losses'`` collection matching it are considered.
        mfb: Object exposing an ``ac_loss`` attribute (presumably the model's
            primary objective tensor -- TODO confirm against the model class).
        use_pretrained_encoder: When true and ``encoder_gradient_ratio`` is
            exactly ``0.0``, collected losses belonging to the encoder /
            mapping sub-graphs are excluded (see below).
        encoder_gradient_ratio: Only compared against ``0.0`` in this function.
            NOTE(review): presumably a ratio of 0.0 means the encoder is
            frozen, so its regularization terms should not count -- confirm.
        weight_decay_scale: Multiplier applied to the summed collection
            losses. Defaults to ``100``, the previously hard-coded factor.

    Returns:
        Tuple ``(total_loss, ac_loss, weight_decay_loss)``. ``weight_decay_loss``
        is the Python float ``0.0`` when no collected losses remain after
        filtering, otherwise the tensor produced by ``tf.add_n``.
    """
    ac_loss = mfb.ac_loss

    # Entries of the 'losses' collection under this tower's scope
    # (presumably the weight-decay terms -- the name suggests so; verify
    # where they are added to the collection).
    weight_decay_loss_list = tf.get_collection('losses', name_scope)

    if use_pretrained_encoder and encoder_gradient_ratio == 0.0:
        # Drop terms whose tensor name contains 'c3d' or 'mapping'
        # (plain substring match anywhere in the name).
        weight_decay_loss_list = [
            loss for loss in weight_decay_loss_list
            if 'c3d' not in loss.name and 'mapping' not in loss.name
        ]

    # Guard the empty case: tf.add_n requires at least one input.
    weight_decay_loss = 0.0
    if weight_decay_loss_list:
        weight_decay_loss = tf.add_n(weight_decay_loss_list)

    total_loss = weight_decay_loss * weight_decay_scale + ac_loss
    return total_loss, ac_loss, weight_decay_loss
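# Web-page navigation residue from the page this code was copied from
# ("评论列表" = "Comment list", "文章目录" = "Table of contents"); not part of the program.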