def _update_iteration_data(self, itr, algorithm, costs, pol_sample_lists):
    """
    Update iteration data information: iteration, average cost, and for
    each condition the mean cost over samples, step size, linear Gaussian
    controller entropies, and initial/final KL divergences for BADMM.

    Args:
        itr: Iteration number (int).
        algorithm: Algorithm object; expected to expose `M`, `cost`,
            `prev`, `cur`, `base_kl_step`, and `_hyperparams`.
        costs: Per-condition sample costs for this iteration.
        pol_sample_lists: List of per-condition policy SampleLists (each a
            singleton), or None if no policy samples were taken.
    """
    avg_cost = np.mean(costs)
    # pol_costs/test_idx exist only when policy samples are available.
    pol_costs = None
    test_idx = None
    if pol_sample_lists is not None:
        test_idx = algorithm._hyperparams['test_conditions']
        # pol_sample_lists is a list of singletons
        samples = [sl[0] for sl in pol_sample_lists]
        pol_costs = [np.sum(algorithm.cost[idx].eval(s)[0])
                     for s, idx in zip(samples, test_idx)]
        itr_data = '%3d | %8.2f %12.2f' % (itr, avg_cost, np.mean(pol_costs))
    else:
        itr_data = '%3d | %8.2f' % (itr, avg_cost)
    for m in range(algorithm.M):
        cost = costs[m]
        # Mean effective KL step for condition m.
        step = np.mean(algorithm.prev[m].step_mult * algorithm.base_kl_step)
        # Linear-Gaussian controller entropy (up to a constant):
        # log det(covar) = 2 * sum(log(diag(cholesky factor))), summed over time.
        entropy = 2*np.sum(np.log(np.diagonal(
            algorithm.prev[m].traj_distr.chol_pol_covar, axis1=1, axis2=2)))
        itr_data += ' | %8.2f %8.2f %8.2f' % (cost, step, entropy)
        if isinstance(algorithm, AlgorithmBADMM):
            kl_div_i = algorithm.cur[m].pol_info.init_kl.mean()
            kl_div_f = algorithm.cur[m].pol_info.prev_kl.mean()
            # BUG FIX: the original referenced pol_costs[m] unconditionally,
            # raising NameError when pol_sample_lists is None.
            if pol_costs is not None:
                itr_data += ' %8.2f %8.2f %8.2f' % (pol_costs[m],
                                                    kl_div_i, kl_div_f)
            else:
                itr_data += ' %8s %8.2f %8.2f' % ("N/A", kl_div_i, kl_div_f)
        elif isinstance(algorithm, AlgorithmMDGPS):
            # TODO: Change for test/train better.
            # BUG FIX: guard on pol_costs/test_idx availability (same
            # NameError as above when no policy samples were taken).
            if (pol_costs is not None and
                    test_idx == algorithm._hyperparams['train_conditions']):
                itr_data += ' %8.2f' % (pol_costs[m])
            else:
                itr_data += ' %8s' % ("N/A")
    self.append_output_text(itr_data)