optimize.py source code

python

Project: fxnn  Author: khaotik
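```python
# Method of VanillaSGD (fxnn, optimize.py). Assumes `import theano as th` at
# module level and a symbolic learning-rate scalar self.s_lr defined elsewhere.
def compile(
        self, s_inputs_, s_loss_,
        v_params_, s_grads_=None, s_reg_=0,
        fetches_=None, updates_=None, givens_=None,
        trunc_grad_=None, profile_=False):
    # Accept one symbolic input or a list/tuple of them. Always keep a list so
    # that [self.s_lr] + s_inputs_ below works (list + tuple raises TypeError).
    if isinstance(s_inputs_, (list, tuple)):
        s_inputs_ = list(s_inputs_)
    else:
        s_inputs_ = [s_inputs_]
    # theano.function takes updates as a list of (shared_var, new_expr) pairs.
    if isinstance(updates_, dict):
        updates_ = list(updates_.items())
    # The base-class compile presumably computes the gradients into self.s_grads.
    super(VanillaSGD, self).compile(
        s_inputs_, s_loss_, v_params_,
        s_reg_=s_reg_, s_grads_=s_grads_, trunc_grad_=trunc_grad_)
    # Vanilla SGD step: p <- p - lr * g for every parameter.
    apply_grad = [(p, p - g * self.s_lr) for p, g in zip(v_params_, self.s_grads)]
    self.fn_train = th.function(
        [self.s_lr] + s_inputs_,  # the learning rate is the first argument
        fetches_,
        updates=apply_grad + (updates_ if updates_ else []),
        givens=givens_,
        on_unused_input='warn',
        profile=profile_,
    )
    return self.fn_train
```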
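The compiled `fn_train` takes the learning rate as its first argument, followed by the model inputs. On each call it evaluates the requested outputs and applies `p - g*lr` to every parameter in `v_params_`. Below is a minimal NumPy sketch of that behaviour without Theano. It assumes the gradients come from a hand-written function. The names `make_train_fn`, `grad_fn` and `fetch_fn` are hypothetical and are not part of fxnn. Extra `updates_` and `givens_` are deliberately left out.

```python
import numpy as np


def make_train_fn(grad_fn, params, fetch_fn=None):
    """Build train(lr, *inputs), a rough NumPy analogue of fn_train.

    grad_fn(params, *inputs) -> one gradient array per parameter (hypothetical).
    fetch_fn(params, *inputs) -> value returned by train (hypothetical).
    """
    def train(lr, *inputs):
        # As with a Theano function, outputs see the parameters before the update.
        out = fetch_fn(params, *inputs) if fetch_fn is not None else None
        # Compute all gradients first, so every update uses the old values.
        grads = grad_fn(params, *inputs)
        # Vanilla SGD step, mirroring apply_grad: p <- p - lr * g (in place).
        for p, g in zip(params, grads):
            p -= lr * g
        return out

    return train


# Usage: minimize 0.5 * ||w - t||^2, whose gradient with respect to w is (w - t).
w = np.zeros(2)
train = make_train_fn(
    grad_fn=lambda ps, t: [ps[0] - t],
    params=[w],
    fetch_fn=lambda ps, t: 0.5 * float(np.sum((ps[0] - t) ** 2)),
)
target = np.array([2.0, -4.0])
print(train(0.5, target), w)  # loss before the step, then the updated w
```

Passing the learning rate as a call argument, rather than baking it into the update expression, lets the caller change it between steps (for example for a decay schedule) without recompiling. This appears to be the reason `self.s_lr` is placed first in the Theano input list.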