def _load_accumulators(self, main_loop):
    """Restore saved step-rule accumulator values for every computation graph.

    Nasty method, use carefully: it relies on positional alignment between
    ``algo.steps`` and ``algo.step_rule_updates`` (see the note in the loop).

    For each ``(cg_name, model)`` in ``main_loop.models`` this:

    1. Loads the archive at ``self.path_to_accumulators.format(cg_name)``
       with ``numpy.load``.
    2. Maps its keys back to parameter names by replacing ``'-'`` with ``'/'``.
    3. Finds, for the ``pidx``-th parameter of ``algo.steps``, the name under
       which ``model.get_params()`` holds that parameter.
    4. Treats the saved array for that name as a stack of ``col`` accumulators
       (``col = len(array)``). Entry ``aidx`` is written with ``set_value``
       into ``algo.step_rule_updates[pidx * col + aidx][0]``.

    Parameters
    ----------
    main_loop : object
        Must expose ``models`` (a mapping from graph name to model) and
        ``algorithm.algorithms`` (a mapping from graph name to an algorithm
        with ``steps`` and ``step_rule_updates``).

    Raises
    ------
    IndexError
        If a parameter of ``algo.steps`` is not among the model's parameters.
    KeyError
        If a graph has no algorithm, or the archive has no entry for a
        matched parameter name.
    """
    for cg_name, model in main_loop.models.items():
        source = numpy.load(self.path_to_accumulators.format(cg_name))
        try:
            # Materialize every array before the archive is closed.
            accums_dict = {name.replace("-", "/"): value
                           for name, value in source.items()}
        finally:
            # Close even if reading a member fails, so the handle never leaks.
            source.close()

        algo = main_loop.algorithm.algorithms[cg_name]
        model_params = model.get_params()
        for pidx, (param, _) in enumerate(algo.steps.items()):
            # Reverse lookup: the model-level name bound to this parameter.
            matches = [k for k, v in model_params.items() if v == param]
            if not matches:
                # IndexError is kept for compatibility with the old `[...][0]`.
                raise IndexError(
                    "Parameter {!r} of computation graph {!r} is not among "
                    "the model's parameters".format(param, cg_name))
            accums = accums_dict[matches[0]]
            # This is num_accums_per_param.
            col = len(accums)
            # NOTE(review): the flat index assumes step_rule_updates holds
            # `col` consecutive entries per parameter, in the same order as
            # algo.steps, and that `col` is equal for every parameter --
            # confirm against the step rule that builds these updates.
            for aidx in range(col):
                algo.step_rule_updates[pidx * col + aidx][0].set_value(
                    accums[aidx])
# (Web-page residue, not code: "评论列表" = "comment list", "文章目录" = "table of contents".)