updates.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目：pl-cnn | 作者：oval-group | 项目源码 | 文件源码
def get_updates(nnet,
                train_obj,
                trainable_params):
    """Build the parameter-update rule for training ``nnet``.

    The solver is taken from ``nnet.solver`` when that attribute exists and
    names one of "nesterov", "adagrad", "adadelta" or "adam". Otherwise the
    function silently falls back to "nesterov". The chosen name is stored
    on ``nnet.sgd_solver`` as a side effect.

    Every solver uses ``Cfg.learning_rate`` as its learning rate. Nesterov
    additionally uses a fixed momentum of 0.9.

    NOTE(review): ``l_updates`` and ``Cfg`` come from module scope outside
    this view. ``l_updates`` is presumably ``lasagne.updates``; confirm.

    Args:
        nnet: network object; only its ``solver`` attribute is read, and
            its ``sgd_solver`` attribute is written.
        train_obj: training objective handed to the update rule.
        trainable_params: parameters the update rule should optimize.

    Returns:
        Whatever the selected ``l_updates`` rule returns (in Lasagne, a
        mapping from parameter to update expression; verify).
    """
    known = ("nesterov", "adagrad", "adadelta", "adam")

    # Missing or unrecognised solver names degrade to Nesterov momentum.
    requested = nnet.solver if hasattr(nnet, "solver") else None
    nnet.sgd_solver = requested if requested in known else "nesterov"

    solver = nnet.sgd_solver

    if solver == "nesterov":
        # The only rule here that takes an extra hyperparameter.
        return l_updates.nesterov_momentum(train_obj,
                                           trainable_params,
                                           learning_rate=Cfg.learning_rate,
                                           momentum=0.9)

    # The remaining rules share one call signature, so pick the callable
    # first and invoke it once. ``solver`` is guaranteed to match one of
    # these branches by the membership check above.
    if solver == "adagrad":
        rule = l_updates.adagrad
    elif solver == "adadelta":
        rule = l_updates.adadelta
    elif solver == "adam":
        rule = l_updates.adam

    return rule(train_obj,
                trainable_params,
                learning_rate=Cfg.learning_rate)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号