gps_training_gui.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:gps 作者: cbfinn 项目源码 文件源码
def _update_iteration_data(self, itr, algorithm, costs, pol_sample_lists):
        """
        Update iteration data information: iteration, average cost, and for
        each condition the mean cost over samples, step size, linear Guassian
        controller entropies, and initial/final KL divergences for BADMM.
        """
        avg_cost = np.mean(costs)
        if pol_sample_lists is not None:
            test_idx = algorithm._hyperparams['test_conditions']
            # pol_sample_lists is a list of singletons
            samples = [sl[0] for sl in pol_sample_lists]
            pol_costs = [np.sum(algorithm.cost[idx].eval(s)[0])
                    for s, idx in zip(samples, test_idx)]
            itr_data = '%3d | %8.2f %12.2f' % (itr, avg_cost, np.mean(pol_costs))
        else:
            itr_data = '%3d | %8.2f' % (itr, avg_cost)
        for m in range(algorithm.M):
            cost = costs[m]
            step = np.mean(algorithm.prev[m].step_mult * algorithm.base_kl_step)
            entropy = 2*np.sum(np.log(np.diagonal(algorithm.prev[m].traj_distr.chol_pol_covar,
                    axis1=1, axis2=2)))
            itr_data += ' | %8.2f %8.2f %8.2f' % (cost, step, entropy)
            if isinstance(algorithm, AlgorithmBADMM):
                kl_div_i = algorithm.cur[m].pol_info.init_kl.mean()
                kl_div_f = algorithm.cur[m].pol_info.prev_kl.mean()
                itr_data += ' %8.2f %8.2f %8.2f' % (pol_costs[m], kl_div_i, kl_div_f)
            elif isinstance(algorithm, AlgorithmMDGPS):
                # TODO: Change for test/train better.
                if test_idx == algorithm._hyperparams['train_conditions']:
                    itr_data += ' %8.2f' % (pol_costs[m])
                else:
                    itr_data += ' %8s' % ("N/A")
        self.append_output_text(itr_data)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号