def compile(self,s_inputs_, s_loss_, v_params_, s_grads_=None, s_reg_=0, fetches_=None, updates_=None, givens_=None, trunc_grad_=None, profile_=False):
    """Compile the Adam training step and the moment-reset function.

    Allocates one first-moment and one second-moment shared variable per
    parameter (stored in ``self.v_m`` / ``self.v_v``), then builds
    ``self.fn_train`` (performs one update step) and ``self.fn_rst``
    (sets every moment accumulator back to zero).

    Args:
        s_inputs_: A symbolic input variable, or a list/tuple of them.
        s_loss_: Forwarded unchanged to the parent ``compile``.
        v_params_: Shared variables to optimise. Iterated several times, so
            it must be a re-iterable sequence.
        s_grads_: Forwarded unchanged to the parent ``compile``.
        s_reg_: Forwarded unchanged to the parent ``compile``.
        fetches_: Passed as ``outputs`` of the training function.
        updates_: Extra updates appended to the Adam updates; either a dict
            or a sequence of ``(shared_variable, expression)`` pairs.
        givens_: Passed as ``givens`` of the training function.
        trunc_grad_: Forwarded unchanged to the parent ``compile``.
        profile_: Passed as ``profile`` to both compiled functions.

    Returns:
        The compiled training function (also stored as ``self.fn_train``).
        Its positional inputs are, in order: the learning rate
        (``self.s_lr``), every entry of ``s_inputs_``, then the four scalars
        ``adam_b1``, ``adam_b2``, ``adam_b1s`` and ``adam_b2s``.
    """
    def get_shared_shape(v):
        # Only the shape is needed, so borrow the internal buffer instead of
        # copying the whole value.
        return v.get_value(borrow=True, return_internal_type=True).shape

    def new_accumulators(prefix):
        # One zero-filled shared variable per parameter, named after the
        # parameter when the parameter itself has a name.
        return [
            th.shared(
                value=np.zeros(get_shared_shape(p), th.config.floatX),
                name=prefix + p.name if p.name is not None else None)
            for p in v_params_]

    # Normalise the inputs to a real list: they are concatenated with other
    # lists below, and ``list + tuple`` raises TypeError.
    if isinstance(s_inputs_, (list, tuple)):
        s_inputs_ = list(s_inputs_)
    else:
        s_inputs_ = [s_inputs_]

    # Same normalisation for the caller-supplied updates.
    if isinstance(updates_, dict):
        extra_updates = list(updates_.items())
    elif updates_:
        extra_updates = list(updates_)
    else:
        extra_updates = []

    # NOTE(review): the parent compile presumably defines self.s_grads and
    # self.s_lr, which are read below; self.eps is also expected to exist
    # already -- confirm against the base class.
    super(AdamSGD,self).compile(
        s_inputs_, s_loss_, v_params_, s_reg_=s_reg_, s_grads_=s_grads_, trunc_grad_=trunc_grad_)

    self.v_m = new_accumulators('adam_m_')
    self.v_v = new_accumulators('adam_v_')

    # Decay rates for the two moment estimates.
    s_b1 = T.scalar('adam_b1')
    s_b2 = T.scalar('adam_b2')
    # Multiplicative scalings applied to the moments in the parameter step.
    # NOTE(review): presumably the bias-correction factors 1/(1-b1**t) and
    # 1/(1-b2**t) -- confirm against the caller that feeds them.
    s_b1s = T.scalar('adam_b1s')
    s_b2s = T.scalar('adam_b2s')

    # Symbolic values of the moments *after* this step.
    new_m = [m*s_b1 + (1.-s_b1)*g for m, g in zip(self.v_m, self.s_grads)]
    new_v = [v*s_b2 + (1.-s_b2)*g*g for v, g in zip(self.v_v, self.s_grads)]
    update_m = list(zip(self.v_m, new_m))
    update_v = list(zip(self.v_v, new_v))

    # Theano evaluates every update expression against the pre-update values
    # of the shared variables. The parameter step therefore has to reference
    # the new moment expressions; reading self.v_m / self.v_v here would apply
    # the previous step's moments (and make the very first step a no-op).
    apply_grad = [
        (p, p - (s_b1s*m*self.s_lr)/(T.sqrt(s_b2s*v) + self.eps))
        for p, m, v in zip(v_params_, new_m, new_v)]

    self.fn_train = th.function(
        inputs=[self.s_lr]+s_inputs_+[s_b1,s_b2,s_b1s,s_b2s],
        outputs=fetches_,
        updates=update_m+update_v+apply_grad+extra_updates,
        on_unused_input='warn',
        givens=givens_, profile=profile_)

    # Resets every first- and second-moment accumulator to zeros.
    self.fn_rst = th.function(inputs=[], updates=[(v, T.zeros_like(v)) for v in self.v_m+self.v_v], profile=profile_)
    return self.fn_train
# Comment list
# Table of contents