def flatten_cost_gradient(cost_gradient_hetero, shapes):
    """
    Adapt a cost function over heterogeneous parameters to one over a single flat vector.

    A model's parameters are usually several arrays of different shapes, which
    cannot be stored together in one regular numpy array, whereas many
    optimizers only operate on a single 1-D parameter vector. The returned
    wrapper splits that vector into the individual parameter arrays, calls
    ``cost_gradient_hetero`` with them, and flattens and concatenates the
    per-parameter gradients it returns.

    :param cost_gradient_hetero: callable ``(parameters, input, output) -> (cost, gradients)``,
        where ``parameters`` is a list of arrays shaped like ``shapes`` and
        ``gradients`` is a matching sequence of array-likes (scalars allowed)
    :param shapes: iterable of parameter shapes, in the order the parameters are
        packed into the flat vector; ``()`` denotes a scalar parameter
    :return: callable ``(concatenated_parameters, input, output) -> (cost, concatenated_gradients)``;
        it raises ``ValueError`` if ``concatenated_parameters`` does not contain
        exactly as many elements as the shapes require
    """
    # Precompute the slice of the flat vector that belongs to each parameter once,
    # instead of on every call (this also makes a one-shot iterator of shapes safe).
    # ``int`` is needed because ``np.prod(())`` is the float 1.0, which is not a
    # valid slice index.
    layout = []
    offset = 0
    for shape in shapes:
        size = int(np.prod(shape))
        layout.append((offset, offset + size, shape))
        offset += size
    total_size = offset

    # ``input`` shadows the builtin, but the name is kept for keyword callers.
    def cost_gradient_wrapper(concatenated_parameters, input, output):
        flat = np.asarray(concatenated_parameters).ravel()
        if flat.size != total_size:
            # Previously surplus values were silently ignored; reject them explicitly.
            raise ValueError(
                f"expected {total_size} parameter values for shapes "
                f"{[shape for _, _, shape in layout]}, got {flat.size}"
            )
        # Slicing and reshaping a contiguous 1-D array yields views into ``flat``.
        all_parameters = [
            flat[start:stop].reshape(shape) for start, stop, shape in layout
        ]
        cost, gradients = cost_gradient_hetero(all_parameters, input, output)
        # ``np.ravel`` also accepts scalars and lists, unlike ``ndarray.flatten``.
        flat_gradients = [np.ravel(gradient) for gradient in gradients]
        if not flat_gradients:
            # ``np.concatenate`` rejects an empty list; no parameters means an empty gradient.
            return cost, np.zeros(0)
        return cost, np.concatenate(flat_gradients)

    return cost_gradient_wrapper
Comment list
Table of contents