def compile(
        self, s_inputs_, s_loss_,
        v_params_, s_grads_=None, s_reg_=0,
        fetches_=None, updates_=None, givens_=None,
        trunc_grad_=None, profile_=False):
    """Compile and return the Theano training function for vanilla SGD.

    Calls the parent class's ``compile`` with the loss, parameters and
    gradient options. It then builds an update ``p <- p - g * self.s_lr``
    for each pair in ``zip(v_params_, self.s_grads)``. Finally it compiles
    a ``th.function`` that applies those updates plus any extra
    ``updates_``.

    NOTE(review): ``self.s_grads`` and ``self.s_lr`` are read here but not
    set here; presumably the parent ``compile`` / constructor provides
    them -- confirm against the base class.

    Args:
        s_inputs_: a single symbolic input, or a list/tuple of them.
        s_loss_: symbolic loss; forwarded to the parent ``compile``.
        v_params_: parameters to update, zipped with ``self.s_grads``.
        s_grads_: optional gradients; forwarded to the parent ``compile``.
        s_reg_: forwarded to the parent ``compile`` (default 0).
        fetches_: outputs of the compiled function (``th.function`` outputs).
        updates_: extra updates, either a dict or a sequence of
            ``(variable, new_value)`` pairs; appended after the SGD updates.
        givens_: forwarded to ``th.function`` as ``givens``.
        trunc_grad_: forwarded to the parent ``compile``.
        profile_: forwarded to ``th.function`` as ``profile``.

    Returns:
        The compiled function, also stored as ``self.fn_train``. Its
        positional inputs are ``self.s_lr`` followed by ``s_inputs_``.
    """
    # Normalise inputs to a list: the result is concatenated with
    # [self.s_lr] below, and list + tuple would raise TypeError.
    if isinstance(s_inputs_, (list, tuple)):
        s_inputs_ = list(s_inputs_)
    else:
        s_inputs_ = [s_inputs_]

    # Extra updates may arrive as a dict or as any sequence of pairs;
    # coerce to a list so it can be appended to the SGD update list.
    if isinstance(updates_, dict):
        updates_ = list(updates_.items())
    extra_updates = list(updates_) if updates_ else []

    super(VanillaSGD, self).compile(
        s_inputs_, s_loss_, v_params_,
        s_reg_=s_reg_, s_grads_=s_grads_, trunc_grad_=trunc_grad_)

    # Plain gradient-descent step for every parameter/gradient pair.
    apply_grad = [(p, p - g * self.s_lr)
                  for p, g in zip(v_params_, self.s_grads)]

    self.fn_train = th.function(
        [self.s_lr] + s_inputs_,
        fetches_,
        updates=apply_grad + extra_updates,
        givens=givens_,
        on_unused_input='warn',
        profile=profile_,
    )
    return self.fn_train
评论列表
文章目录