neural_network.py source code

python

Project: dl4nlp  Author: yohokuno
import numpy as np


def flatten_cost_gradient(cost_gradient_hetero, shapes):
    """
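    Wrap a cost function that takes a list of differently shaped parameters
    (which cannot be stored in a single rectangular numpy array) so that it
    works on one flat parameter vector.
    :param cost_gradient_hetero: cost function that receives a list of parameter arrays
        and returns (cost, list of gradients)
    :param shapes: list of parameter shapes, in the order the parameters are concatenated
    :return: cost function that receives the concatenated parameters and returns
        (cost, concatenated gradients)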
    """
    def cost_gradient_wrapper(concatenated_parameters, input, output):
        all_parameters = []

        for shape in shapes:
            # Cast to int: np.prod(()) is the float 1.0, which np.split rejects as an index
            split_index = int(np.prod(shape))
            single_parameter, concatenated_parameters = np.split(concatenated_parameters, [split_index])
            single_parameter = single_parameter.reshape(shape)
            all_parameters.append(single_parameter)

        cost, gradients = cost_gradient_hetero(all_parameters, input, output)
        flatten_gradients = [gradient.flatten() for gradient in gradients]
        concatenated_gradients = np.concatenate(flatten_gradients)
        return cost, concatenated_gradients

    return cost_gradient_wrapper
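
A minimal usage sketch, assuming a hypothetical squared-error cost for a linear model (`linear_cost_gradient` and the sample data are not part of the original project). It shows how the wrapper lets a function with a weight matrix and a bias vector be driven by one flat parameter vector. The wrapper is repeated in compact form only so the sketch runs on its own.

```python
import numpy as np


def flatten_cost_gradient(cost_gradient_hetero, shapes):
    # Compact copy of the wrapper above so this sketch is self-contained.
    def cost_gradient_wrapper(concatenated_parameters, input, output):
        all_parameters = []
        for shape in shapes:
            split_index = int(np.prod(shape))
            single_parameter, concatenated_parameters = np.split(
                concatenated_parameters, [split_index])
            all_parameters.append(single_parameter.reshape(shape))
        cost, gradients = cost_gradient_hetero(all_parameters, input, output)
        return cost, np.concatenate([g.flatten() for g in gradients])
    return cost_gradient_wrapper


def linear_cost_gradient(parameters, input, output):
    # Hypothetical cost (an assumption for illustration):
    # 0.5 * ||input @ weights + bias - output||^2
    weights, bias = parameters
    error = input @ weights + bias - output
    cost = 0.5 * np.sum(error ** 2)
    # Gradients are returned in the same order as the parameters.
    return cost, [np.outer(input, error), error]


shapes = [(2, 3), (3,)]
wrapped = flatten_cost_gradient(linear_cost_gradient, shapes)

flat_parameters = np.zeros(9)  # 2*3 weights followed by 3 biases
x = np.array([1.0, 2.0])
y = np.array([1.0, 0.0, -1.0])
cost, flat_gradient = wrapped(flat_parameters, x, y)
```

Because the wrapped function maps a flat vector to a scalar cost and a flat gradient of the same length, it can be handed to a generic gradient checker or optimizer that expects 1-D parameters.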