monitor.py source code

python

Project: wavenet  Author: rampage644
import chainer


def layer_params(layer, param_name, attr_name):
    """Return parameters in a flattened array from the given layer, or an empty
    array if the parameters are not found.

    Args:
        layer (~chainer.Link): The layer from which parameters are collected.
        param_name (str): Name of the parameter, ``'W'`` or ``'b'``.
        attr_name (str): Name of the attribute, ``'data'`` or ``'grad'``.

    Returns:
        array: Flattened array of parameters.
    """
    if isinstance(layer, chainer.Chain):
        # Nested chainer.Chain: aggregate the statistics of all child links.
        # layers_params is presumably defined elsewhere in this monitor.py.
        return layers_params(layer, param_name, attr_name)
    elif not hasattr(layer, param_name):
        # This link has no such parameter (e.g. a layer without a bias).
        return layer.xp.array([])

    params = getattr(layer, param_name)
    params = getattr(params, attr_name)
    if params is None:
        # data/grad can be None (e.g. an uninitialized parameter, or a
        # gradient before the first backward pass); treat it as "not found".
        return layer.xp.array([])
    return params.flatten()
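To show how the recursive collection works without requiring Chainer, here is a minimal sketch that uses NumPy only. The classes Param, Link and Chain are hypothetical stand-ins for chainer.Variable/Parameter, chainer.Link and chainer.Chain (not the Chainer API), and layers_params is an assumed implementation that simply concatenates the flattened arrays of every child; the real helper in monitor.py may differ.

```python
import numpy as np


class Param:
    """Hypothetical stand-in for a Chainer parameter: holds data and grad."""

    def __init__(self, data, grad=None):
        self.data = np.asarray(data, dtype=np.float32)
        self.grad = None if grad is None else np.asarray(grad, dtype=np.float32)


class Link:
    """Hypothetical leaf layer exposing optional parameters such as W and b."""

    def __init__(self, **params):
        for name, p in params.items():
            setattr(self, name, p)


class Chain:
    """Hypothetical container of child links (mirrors a nested chainer.Chain)."""

    def __init__(self, *children):
        self.children = list(children)


def layer_params(layer, param_name, attr_name):
    # Nested containers are delegated to the aggregating helper.
    if isinstance(layer, Chain):
        return layers_params(layer, param_name, attr_name)
    if not hasattr(layer, param_name):
        return np.array([], dtype=np.float32)
    values = getattr(getattr(layer, param_name), attr_name)
    if values is None:  # e.g. gradients before the first backward pass
        return np.array([], dtype=np.float32)
    return values.ravel()


def layers_params(chain, param_name, attr_name):
    # Assumed helper: concatenate every child's flattened array into one 1-D array.
    parts = [layer_params(c, param_name, attr_name) for c in chain.children]
    if not parts:
        return np.array([], dtype=np.float32)
    return np.concatenate(parts)


if __name__ == "__main__":
    model = Chain(
        Link(W=Param([[1, 2], [3, 4]]), b=Param([0.5, -0.5])),
        Chain(Link(W=Param([5, 6]))),  # nested chain whose link has no bias
    )
    print(layers_params(model, "W", "data"))  # [1. 2. 3. 4. 5. 6.]
    print(layers_params(model, "W", "grad").size)  # 0, no gradients yet
```

Flattening every parameter into a single 1-D array is what makes it easy to compute model-wide statistics (mean, standard deviation, percentiles) with one call, regardless of how deeply the links are nested.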