def load_params(self, f_, filter_=None):
    """Load parameter values from a pickled mapping into ``self._vars_di``.

    ``f_`` is read with ``pickle.load`` and must yield a mapping of
    variable name -> value. Each selected entry is written into
    ``self._vars_di[name]`` via ``set_value`` once its shape has been
    checked against the value currently returned by
    ``get_value(borrow=True)``. (Presumably these are Theano-style shared
    variables -- confirm against the owning class.)

    Args:
        f_: Readable binary file object positioned at the pickled mapping.
        filter_: Optional regular expression (string or compiled pattern).
            When given, only names that match it *entirely*
            (``Pattern.fullmatch``) are loaded; every other entry in the
            pickle is skipped, even if it has no matching variable.
            ``None`` loads every entry.

    Raises:
        KeyError: presumably, when a selected name has no entry in
            ``self._vars_di`` (raised by the ``self._vars_di[name]`` lookup).
        ValueError: A selected value's shape differs from the target
            variable's current shape. All entries are validated before
            any assignment, so on error no variable has been modified.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only call this on parameter files from a trusted source.
    di = pickle.load(f_)
    pat = None if filter_ is None else re.compile(filter_)

    # First pass: select and validate everything without side effects, so a
    # shape mismatch on a later key cannot leave earlier variables overwritten.
    updates = []
    for name, value in di.items():
        if pat is not None and not pat.fullmatch(name):
            continue
        var = self._vars_di[name]
        current = var.get_value(borrow=True)
        if current.shape != value.shape:
            # "need" is the variable's existing shape; "got" is what the file holds.
            raise ValueError(
                'Shape mismatch for %r: need %s, got %s'
                % (name, current.shape, value.shape)
            )
        updates.append((var, value))

    # Second pass: every entry is known to be compatible; apply them.
    for var, value in updates:
        var.set_value(value)
评论列表
文章目录