def _multi_head(heads, loss_weights=None):
"""Creates a MultiHead stemming from same logits/hidden layer.
Args:
heads: list of _Head objects.
loss_weights: optional list of weights to be used to combine losses from
each head. All losses are weighted equally if not provided.
Returns:
A _Head instance that combines multiple heads.
Raises:
ValueError: if heads and loss_weights have different size.
"""
if loss_weights:
if len(loss_weights) != len(heads):
raise ValueError("heads and loss_weights must have same size")
def _weighted_loss_combiner(losses):
if loss_weights:
if len(losses) != len(loss_weights):
raise ValueError("losses and loss_weights must have same size")
weighted_losses = []
for loss, weight in zip(losses, loss_weights):
weighted_losses.append(math_ops.multiply(loss, weight))
return math_ops.add_n(weighted_losses)
else:
return math_ops.add_n(losses)
return _MultiHead(heads, loss_combiner=_weighted_loss_combiner)
# TODO(zakaria): Make the classes public once we are ready for users to subclass
# them.
head.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录