def layers_params(model, param_name, attr_name):
    """Return matching parameter values of ``model`` as one flattened array.

    Every parameter yielded by ``model.params()`` whose ``name`` equals
    ``param_name`` contributes ``getattr(param, attr_name)``, flattened to
    1-D, in iteration order.

    Args:
        model (~chainer.Chain): The model from which parameters are collected.
            Its ``xp`` attribute supplies the array module used to build the
            result.
        param_name (str): Name of the parameter, e.g. ``'W'`` or ``'b'``.
        attr_name (str): Name of the attribute, e.g. ``'data'`` or ``'grad'``.

    Returns:
        array: 1-D array holding the concatenated values. It is empty with
        dtype ``float32`` when no parameter matches; otherwise its dtype is
        ``float32`` promoted with the dtypes of the collected arrays.
    """
    xp = model.xp
    # Seed with an empty float32 array so the result is float32 when nothing
    # matches, and so dtype promotion matches concatenating onto float32.
    chunks = [xp.array([], dtype=xp.float32)]
    for param in model.params():
        if param.name == param_name:
            chunks.append(getattr(param, attr_name).ravel())
    # A single concatenate copies each value once; concatenating inside the
    # loop re-copied the growing result on every match (quadratic overall).
    return xp.concatenate(chunks)
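# 评论列表 ("Comment list") and 文章目录 ("Table of contents"): this appears to be
# web-page navigation text left over in the file; it is not part of the program.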