optimize.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目: fxnn 作者: khaotik 项目源码 文件源码
def compile(self,s_inputs_, s_loss_, v_params_, s_grads_=None, s_reg_=0, fetches_=None, updates_=None, givens_=None, trunc_grad_=None, profile_=False):
        """Build the Adam training function and the moment-reset function.

        Delegates gradient setup to the parent ``compile``. It then creates one
        zero-initialised first-moment (``self.v_m``) and second-moment
        (``self.v_v``) shared accumulator per parameter, and compiles:

        * ``self.fn_train``: performs one Adam step and any extra updates.
        * ``self.fn_rst``: resets every moment accumulator to zeros.

        Args:
            s_inputs_: a symbolic input variable, or a list/tuple of them.
            s_loss_: symbolic loss, forwarded to the parent ``compile``.
            v_params_: shared variables to optimise.
            s_grads_: optional symbolic gradients, forwarded to the parent.
            s_reg_: regularisation term, forwarded to the parent.
            fetches_: ``outputs`` of ``fn_train``.
            updates_: extra updates, either a dict or a sequence of
                ``(shared_var, new_value_expr)`` pairs.
            givens_: ``givens`` passed to ``theano.function``.
            trunc_grad_: forwarded to the parent ``compile``.
            profile_: ``profile`` flag passed to both compiled functions.

        Returns:
            ``self.fn_train``, whose inputs are
            ``[lr] + s_inputs_ + [b1, b2, b1s, b2s]``. ``b1``/``b2`` are the
            moment decay rates. ``b1s``/``b2s`` multiply the moments in the
            step; presumably they are the bias-correction factors
            ``1/(1-b1**t)`` and ``1/(1-b2**t)`` (NOTE(review): confirm
            against callers).
        """
        def get_shared_shape(v):
            # Read the shape without copying the underlying buffer.
            return v.get_value(borrow=True, return_internal_type=True).shape

        def make_accumulator(prefix, p):
            # Zero-filled state with the same shape as parameter p, named
            # after p when p has a name.
            name = prefix + p.name if p.name is not None else None
            return th.shared(value=np.zeros(get_shared_shape(p), th.config.floatX), name=name)

        if isinstance(s_inputs_, (list, tuple)):
            # Always work on a list: a tuple would break the list concatenation
            # used to build fn_train's inputs below.
            s_inputs_ = list(s_inputs_)
        else:
            s_inputs_ = [s_inputs_]
        if isinstance(updates_, dict):
            extra_updates = list(updates_.items())
        else:
            # Accept None, lists, or tuples of (shared, expr) pairs.
            extra_updates = list(updates_) if updates_ else []

        super(AdamSGD,self).compile(
            s_inputs_, s_loss_, v_params_, s_reg_=s_reg_, s_grads_=s_grads_, trunc_grad_=trunc_grad_)

        self.v_m = [make_accumulator('adam_m_', p) for p in v_params_]
        self.v_v = [make_accumulator('adam_v_', p) for p in v_params_]

        s_b1 = T.scalar('adam_b1'); s_b2 = T.scalar('adam_b2')
        s_b1s = T.scalar('adam_b1s'); s_b2s = T.scalar('adam_b2s')

        # Exponential moving averages of the gradient and the squared gradient.
        new_m = [m*s_b1 + (1.-s_b1)*g for m, g in zip(self.v_m, self.s_grads)]
        new_v = [v*s_b2 + (1.-s_b2)*g*g for v, g in zip(self.v_v, self.s_grads)]
        update_m = list(zip(self.v_m, new_m))
        update_v = list(zip(self.v_v, new_v))
        # Theano computes every update from the values held *before* the call.
        # The parameter step must therefore use the freshly computed moments.
        # Using the shared variables m/v here would apply the previous
        # iteration's moments, which makes the first step a no-op.
        apply_grad = [(p, p-(s_b1s*m*self.s_lr)/(T.sqrt(s_b2s*v)+self.eps))
                      for p, m, v in zip(v_params_, new_m, new_v)]

        self.fn_train = th.function(
            inputs=[self.s_lr]+s_inputs_+[s_b1,s_b2,s_b1s,s_b2s],
            outputs=fetches_,
            updates=update_m+update_v+apply_grad+extra_updates,
            on_unused_input='warn',
            givens=givens_, profile=profile_)
        self.fn_rst = th.function(inputs=[], updates=[(v, T.zeros_like(v)) for v in self.v_m+self.v_v], profile=profile_)
        return self.fn_train
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号