def layer_params(layer, param_name, attr_name):
    """Return parameters in a flattened array from the given layer or an empty
    array if the parameters are not found.

    Args:
        layer (~chainer.Link): The layer from which parameters are collected.
        param_name (str): Name of the parameter, ``'W'`` or ``'b'``.
        attr_name (str): Name of the attribute, ``'data'`` or ``'grad'``.

    Returns:
        array: Flattened array of parameters. For a nested ``chainer.Chain``
        the result of ``layers_params`` is returned instead. An empty array
        (built with ``layer.xp``) is returned when ``layer`` has no attribute
        named ``param_name``, when that attribute is ``None``, or when the
        requested ``attr_name`` value on it is ``None``.
    """
    if isinstance(layer, chainer.Chain):
        # Nested chainer.Chain: aggregate all underlying statistics.
        return layers_params(layer, param_name, attr_name)

    # A single lookup treats "missing" and "present but None" alike; the
    # original hasattr() check let a None attribute through and then crashed
    # in the getattr() below.
    param = getattr(layer, param_name, None)
    if param is None:
        return layer.xp.array([])

    values = getattr(param, attr_name)
    if values is None:
        # Nothing to flatten (e.g. a value that has not been set yet);
        # honour the "empty array if not found" contract instead of raising.
        return layer.xp.array([])
    return values.flatten()
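# Page-navigation residue from the source web page ("Comment list",
# "Table of contents"); not part of the code.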