def vector_to_parameters(vec, parameters):
    """Copy the contents of one flat vector back into a model's parameters.

    The vector is consumed front to back: each parameter, in the order the
    iterable yields them, takes the next ``numel`` elements of ``vec``,
    reshaped to that parameter's size, as its new ``.data``.

    Arguments:
        vec (Variable): a single vector representing the parameters of a
            model (assumed 1-D, laid out in the same order as
            ``parameters`` -- TODO confirm with callers).
        parameters (Iterable[Variable]): an iterable of Variables that are
            the parameters of a model. It is iterated exactly once.

    Returns:
        None. The parameters are updated in place by rebinding ``.data``.

    Raises:
        TypeError: if ``vec`` is not a ``Variable``.

    Note:
        The length of ``vec`` is not checked against the total number of
        parameter elements. Surplus trailing elements are silently ignored.
        A vector that is too short yields a slice that ``view`` cannot
        reshape, which is expected to raise an error from PyTorch.
    """
    from functools import reduce
    import operator

    # Ensure vec of type Variable
    if not isinstance(vec, Variable):
        raise TypeError('expected torch.autograd.Variable, but got: {}'
                        .format(torch.typename(vec)))
    # Device of the parameters seen so far. _check_param_device (defined
    # elsewhere) is intended to ensure all parameters share one device.
    param_device = None
    # Offset into vec of the first element not yet assigned to a parameter
    pointer = 0
    for param in parameters:
        param_device = _check_param_device(param, param_device)
        shape = param.size()
        # Element count as a plain Python int. The previous
        # torch.prod(torch.LongTensor(...)) allocated a tensor per parameter,
        # and on newer PyTorch it returns a 0-dim tensor instead of an int,
        # which then leaked into ``pointer``. The initial value 1 keeps 0-dim
        # (scalar) parameters, whose size() is empty, at one element.
        num_param = reduce(operator.mul, shape, 1)
        # Slice the vector, reshape it, and replace the old data of the parameter
        param.data = vec[pointer:pointer + num_param].view(shape).data
        pointer += num_param
评论列表
文章目录