chart_utils.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:simple_rl 作者: david-abel 项目源码 文件源码
def compute_conf_intervals(data, cumulative=False):
    '''
    Computes one confidence interval per episode for every algorithm.

    Args:
        data (list): A 3D matrix, [algorithm][instance][episode]
        cumulative (bool) *opt: If True, each instance's values are replaced
            by their running sum over episodes before the interval for an
            episode is computed.

    Returns:
        (list): A 2D list, [algorithm][episode], holding whatever
            compute_single_conf_interval returns for the vector of
            per-instance values at that episode. An algorithm that has no
            instances yields an empty list.
    '''

    confidence_intervals_each_alg = []  # [alg][conf_inv_for_episode]

    for all_instances in data:

        num_instances = len(all_instances)
        if num_instances == 0:
            # No instances means no episodes to summarize; avoid indexing [0].
            confidence_intervals_each_alg.append([])
            continue

        # The episode count is taken from the first instance.
        num_episodes = len(all_instances[0])

        all_instances = np.array(all_instances)
        alg_i_ci = []
        total_so_far = np.zeros(num_instances)

        # range (not xrange) so the function runs on Python 2 and Python 3.
        for j in range(num_episodes):
            # Values of episode j across all instances.
            episode_j_all_instances = all_instances[:, j]

            if cumulative:
                # Fold this episode into the per-instance running totals and
                # use the totals as the datum for the interval.
                total_so_far = np.add(episode_j_all_instances, total_so_far)
                episode_j_all_instances = total_so_far

            # Compute the interval and add it to list.
            alg_i_ci.append(compute_single_conf_interval(episode_j_all_instances))

        confidence_intervals_each_alg.append(alg_i_ci)

    return confidence_intervals_each_alg
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号