utils.py source code

python

Project: deepcpg  Author: cangermueller
def copy_weights(src_model, dst_model, must_exist=True):
    """Copy weights from `src_model` to `dst_model`.

    Parameters
    ----------
    src_model
        Keras source model.
    dst_model
        Keras destination model.
    must_exist : bool, optional
        If `True` (the default), raise `ValueError` if a layer in `dst_model`
        does not exist in `src_model`. If `False`, skip such layers.

    Returns
    -------
    list
        Names of layers that were copied.
    """
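    copied = []
    # Index source layers by name; Keras layer names are unique within a model.
    src_layers = {layer.name: layer for layer in src_model.layers}
    for dst_layer in dst_model.layers:
        src_layer = src_layers.get(dst_layer.name)
        if src_layer is None:
            if must_exist:
                # Report the destination layer that has no source counterpart.
                raise ValueError('Layer "%s" not found!' % dst_layer.name)
            # No matching source layer: leave this destination layer untouched.
            continue
        # Shapes must match, otherwise Keras `set_weights` raises ValueError.
        dst_layer.set_weights(src_layer.get_weights())
        copied.append(dst_layer.name)
    return copied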
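Because `copy_weights` matches layers by name only, it can help to check which layers would be transferred before touching any weights, for example when `dst_model` replaces the output layer of a pretrained `src_model`. The following is a minimal sketch of such a dry run, not part of deepcpg: `preview_copy`, `StubLayer` and `StubModel` are hypothetical names introduced here, and the stubs stand in for Keras objects by exposing only the `name` and `layers` attributes that the lookup needs.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class StubLayer:
    """Hypothetical stand-in for a Keras layer; only the name is used."""
    name: str


@dataclass
class StubModel:
    """Hypothetical stand-in for a Keras model exposing `.layers`."""
    layers: List[StubLayer] = field(default_factory=list)


def preview_copy(src_model, dst_model, must_exist=True):
    """Return the layer names `copy_weights` would copy, copying nothing.

    Uses the same name-based rule: each destination layer is matched to the
    source layer with the same name; unmatched layers raise or are skipped.
    """
    src_names = {layer.name for layer in src_model.layers}
    matched = []
    for dst_layer in dst_model.layers:
        if dst_layer.name not in src_names:
            if must_exist:
                raise ValueError('Layer "%s" not found!' % dst_layer.name)
            continue
        matched.append(dst_layer.name)
    return matched


# Hypothetical pretrained model vs. a new model with a different output layer.
src = StubModel([StubLayer("conv1"), StubLayer("dense1"), StubLayer("head_a")])
dst = StubModel([StubLayer("conv1"), StubLayer("dense1"), StubLayer("head_b")])

print(preview_copy(src, dst, must_exist=False))  # ['conv1', 'dense1']
```

With `must_exist=True` (the default), the same pair of models raises `ValueError` for `head_b`, which matches the behaviour documented for `copy_weights`.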