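```python
import tensorflow as tf  # TF 1.x: tf.contrib was removed in TF 2.x
def add_regularization_loss(self):
    # A scale of 0.0 disables L2 regularization, so no penalty term is built.
    if self.config.l2_regularization == 0.0:
        return 0
    # Regularize only weight matrices (".../kernel:0" or ".../weights:0"); biases are skipped.
    weights = [w for w in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if w.name.split('/')[-1] in ('kernel:0', 'weights:0')]
    # Sum over the selected weights of scale * tf.nn.l2_loss(w), i.e. scale * sum(w ** 2) / 2.
    return tf.contrib.layers.apply_regularization(
        tf.contrib.layers.l2_regularizer(self.config.l2_regularization), weights)
```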
Source file: base_aligner.py
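To make the arithmetic concrete, here is a plain-Python sketch of the penalty the method above builds. It assumes TF 1.x semantics, where `tf.contrib.layers.l2_regularizer(scale)(w)` equals `scale * tf.nn.l2_loss(w)` and `tf.nn.l2_loss(w)` is `sum(w ** 2) / 2`. The function name `l2_penalty` and its dict-of-lists input are hypothetical stand-ins for graph variables. One behavior is not mirrored: given an empty list, `apply_regularization` falls back to the `GraphKeys.WEIGHTS` collection and raises `ValueError` if that is empty too, whereas this sketch simply returns 0.0.

```python
def l2_penalty(named_weights, scale):
    """Plain-Python sketch of the L2 penalty built by add_regularization_loss.

    named_weights: hypothetical input mapping a TF-style variable name
    (e.g. 'dense/kernel:0') to a flat list of float values.
    """
    # A scale of 0.0 disables the penalty, as in the original method.
    if scale == 0.0:
        return 0.0
    total = 0.0
    for name, values in named_weights.items():
        # Same name filter as the original: keep kernel/weights, skip biases etc.
        if name.split('/')[-1] in ('kernel:0', 'weights:0'):
            # scale * l2_loss(w) == scale * sum(w ** 2) / 2
            total += scale * sum(v * v for v in values) / 2.0
    return total


# Usage: only the kernel and weights variables contribute.
weights = {
    'encoder/dense/kernel:0': [1.0, 2.0],  # 0.5 * (1 + 4) / 2 = 1.25
    'encoder/dense/bias:0': [10.0],        # excluded by the name filter
    'decoder/proj/weights:0': [3.0],       # 0.5 * 9 / 2 = 2.25
}
print(l2_penalty(weights, 0.5))  # 3.5
```

Filtering on the final name component is a common way to skip biases, though it depends on the variable naming used by the layer library that created them.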