basic_segment_statistics.py source code

python

Project: kaggle-seizure-prediction  Author: sics-lm
import numpy as np
import pandas as pd


def get_subject_metric(stats_df, metric_name, aggregator='{dataframe}.median()', channel_ordering=None, use_cache=True):
    """
    Returns the metric from the stats dataframe as an NDArray-like of shape (n_channels, 1), obtained by aggregating
    the given metric over all segments in the dataframe.
    :param stats_df: The statistics dataframe acquired from read_stats.
    :param metric_name: The metric to collect.
    :param aggregator: A string with an expression used to aggregate the per segment statistic to a single statistic for
                       the whole subject.
    :param channel_ordering: An optional ordered sequence of channel names, which will ensure that the outputted
    statistics vector has the same order as the segment which the statistic should be applied on.
    :param use_cache: If True, the metrics will be cached by the function so that calling it multiple times for the
                      same subject is fast.
    :return: An NDArray of shape (n_channels, 1) where each element along axis 0 corresponds to the aggregated
             statistic for that channel.
    """
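    assert isinstance(stats_df, pd.DataFrame)
    # The cache is stored as an attribute on the function; create it on first use if the module has not done so.
    if not hasattr(get_subject_metric, 'cache'):
        get_subject_metric.cache = {}
    cache = get_subject_metric.cache
    # Note: entries are keyed by id(stats_df), which is only unique while that dataframe object is alive.
    cache_key = (metric_name, aggregator)
    if use_cache and cache_key in cache.get(id(stats_df), {}):
        return cache[id(stats_df)][cache_key]

    # The stats dataframes have a 2-level column index, where the first level is the channel name and the second
    # is the metric name. To get the metric but keep the channels, we slice the first level with all its entries
    # using slice(None), which is equivalent to [:] for regular slices.
    if channel_ordering is None:
        segment_metrics = stats_df.loc[:, (slice(None), metric_name)]
    else:
        segment_metrics = stats_df.loc[:, (channel_ordering, metric_name)]
    # The aggregator string is executed with eval, so it must come from a trusted source.
    aggregated_metric = eval(aggregator.format(dataframe='segment_metrics'))
    # Convert to a NumPy array before adding the axis; indexing a pandas Series with [:, np.newaxis] is not
    # supported in recent pandas versions.
    added_axis = np.asarray(aggregated_metric)[:, np.newaxis]
    if use_cache:
        cache.setdefault(id(stats_df), {})[cache_key] = added_axis
    return added_axis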
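
The column selection and aggregation used in the function can be tried on a small table. The following is a minimal sketch, not code from the project: the dataframe, the channel names `ch1` and `ch2`, and the metric names `mean` and `std` are invented for illustration, and the real output of read_stats may be laid out differently.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for a stats dataframe: one row per segment,
# columns indexed by (channel name, metric name).
columns = pd.MultiIndex.from_product([["ch1", "ch2"], ["mean", "std"]])
stats_df = pd.DataFrame(
    [[1.0, 0.1, 4.0, 0.4],
     [2.0, 0.2, 5.0, 0.5],
     [3.0, 0.3, 6.0, 0.6]],
    columns=columns,
)

# Keep every channel (slice(None) on the first level) but only the "mean" metric.
segment_metrics = stats_df.loc[:, (slice(None), "mean")]

# The aggregator is a template string; the placeholder is replaced by the
# variable name and the resulting expression is evaluated.
aggregator = "{dataframe}.median()"
aggregated = eval(aggregator.format(dataframe="segment_metrics"))

# Turn the per-channel Series into a column vector of shape (n_channels, 1).
column_vector = np.asarray(aggregated)[:, np.newaxis]
print(column_vector.shape)  # (2, 1)
```

A vector of shape (n_channels, 1) broadcasts against a segment array of shape (n_channels, n_samples), which is presumably why the function adds the trailing axis.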