monitor.py source code

python

Project: wavenet  Author: rampage644
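def layers_params(model, param_name, attr_name):
    """Return the matching parameters of the given model as one flattened array.

    Args:
        model (~chainer.Chain): The model from which parameters are collected.
        param_name (str): Name of the parameter, ``'W'`` or ``'b'``.
        attr_name (str): Name of the attribute, ``'data'`` or ``'grad'``.

    Returns:
        array: 1-D array holding the values of every matching parameter,
        or an empty ``float32`` array if nothing matches.
    """

    xp = model.xp  # numpy on CPU, cupy on GPU

    # Collect the flattened pieces first and concatenate once at the end;
    # calling xp.concatenate inside the loop copies the whole accumulated
    # array on every iteration, which is quadratic in the total size.
    arrays = []
    for param in model.params():
        if param.name == param_name:
            values = getattr(param, attr_name)
            if values is None:  # e.g. grad before the first backward pass
                continue
            arrays.append(values.ravel())

    if not arrays:
        return xp.array([], dtype=xp.float32)
    return xp.concatenate(arrays)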
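To show what the function returns without installing Chainer, the sketch that follows swaps the model for hypothetical stand-ins (`FakeParam`, `FakeModel`) that expose only the attributes the function touches: `.name`, `.data`, `.grad`, `.xp` and `.params()`. The `summarize` helper is likewise an assumption about how a monitoring script might use the flattened array (logging mean, standard deviation and range of weights or gradients); it is not taken from the wavenet project.

```python
import numpy as np


class FakeParam:
    """Hypothetical stand-in for chainer.Parameter (only name/data/grad)."""

    def __init__(self, name, data, grad=None):
        self.name = name
        self.data = data
        self.grad = grad


class FakeModel:
    """Hypothetical stand-in for chainer.Chain exposing .xp and .params()."""

    xp = np

    def __init__(self, params):
        self._params = params

    def params(self):
        return iter(self._params)


def layers_params(model, param_name, attr_name):
    """Flatten the `attr_name` array of every `param_name` parameter into one vector."""
    xp = model.xp
    arrays = [
        getattr(p, attr_name).ravel()
        for p in model.params()
        if p.name == param_name and getattr(p, attr_name) is not None
    ]
    if not arrays:
        return xp.array([], dtype=xp.float32)
    # A single concatenate at the end is linear in the total number of values.
    return xp.concatenate(arrays)


def summarize(values):
    """Summary statistics a monitor might log (assumed use, not from the source)."""
    if values.size == 0:
        return {}
    return {
        "mean": float(values.mean()),
        "std": float(values.std()),
        "min": float(values.min()),
        "max": float(values.max()),
    }


model = FakeModel([
    FakeParam("W", np.ones((2, 3), dtype=np.float32)),
    FakeParam("b", np.zeros(2, dtype=np.float32)),
    FakeParam("W", np.full((4,), 3.0, dtype=np.float32)),
])

weights = layers_params(model, "W", "data")
print(weights.shape)  # (10,)
print(summarize(weights))

# No gradients have been computed yet, so every grad is None and the result is empty.
grads = layers_params(model, "W", "grad")
print(grads.size)  # 0
```

Only the two `"W"` parameters are picked up (6 + 4 = 10 values); the `"b"` parameter is ignored because its name does not match.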