def get_subject_metric(stats_df, metric_name, aggregator='{dataframe}.median()', channel_ordering=None, use_cache=True):
    """
    Return the given metric from ``stats_df`` as a numpy array of shape (n_channels, 1), obtained by aggregating the
    per-segment values of that metric.

    :param stats_df: The statistics dataframe acquired from read_stats. Its columns are expected to be a 2-level
        MultiIndex whose first level is the channel name and whose second level is the metric name.
    :param metric_name: The metric to collect (a label of the second column level).
    :param aggregator: A Python expression string in which ``{dataframe}`` is a placeholder for the per-segment metric
        dataframe. It aggregates the per-segment statistic into a single statistic per channel for the whole
        subject. The expression is passed to ``eval``, so it must only ever come from trusted code.
    :param channel_ordering: An optional ordered sequence of channel names. When given, only these channels are
        selected, in this order, so that the output lines up with the segment the statistic will be applied to.
    :param use_cache: If True, a result previously computed for the same dataframe object, metric, aggregator and
        channel ordering is returned without recomputation. Freshly computed results are always stored in the cache,
        so passing False forces a refresh. The cache is keyed on object identity, so modifying ``stats_df`` in place
        after a cached call will NOT invalidate the cached result.
    :return: A numpy array of shape (n_channels, 1) where element i along axis 0 is the aggregated statistic for the
        i-th selected channel.
    """
    import weakref  # local import keeps this block self-contained

    assert isinstance(stats_df, pd.DataFrame)

    # The cache lives on the function object; create it lazily if the module has not set it up.
    cache = getattr(get_subject_metric, 'cache', None)
    if cache is None:
        cache = get_subject_metric.cache = {}

    if channel_ordering is None:
        cache_key = (metric_name, aggregator)
    else:
        # The channel selection/order changes the result, so it must be part of the cache key.
        cache_key = (metric_name, aggregator, tuple(channel_ordering))

    df_key = id(stats_df)
    df_cache = cache.get(df_key)  # .get() does not insert, even if the cache is a defaultdict
    if use_cache and df_cache is not None and cache_key in df_cache:
        return df_cache[cache_key]

    # The stats dataframes have a 2-level column index, where the first level are the channel names and the second
    # the metric name. To get the metric but keep the channels we slice the first level with all the entries using
    # slice(None), this is equivalent to [:] for regular slices.
    if channel_ordering is None:
        segment_metrics = stats_df.loc[:, (slice(None), metric_name)]
    else:
        segment_metrics = stats_df.loc[:, (channel_ordering, metric_name)]

    # SECURITY: `aggregator` is executed as Python code with access to this function's scope; never pass strings
    # derived from untrusted input. The formatted expression refers to the local name `segment_metrics`.
    aggregated_metric = eval(aggregator.format(dataframe='segment_metrics'))

    # Convert to numpy before adding the axis: pandas 2.0 no longer supports multi-dimensional indexing such as
    # `series[:, np.newaxis]` on a Series.
    added_axis = np.asarray(aggregated_metric)[:, np.newaxis]

    if df_cache is None:
        df_cache = cache[df_key] = {}
        # id() values are recycled after an object is garbage collected; evict this dataframe's entries when it dies
        # so that a later dataframe reusing the same id cannot receive stale results (and the cache does not grow
        # without bound).
        try:
            weakref.finalize(stats_df, cache.pop, df_key, None)
        except TypeError:
            pass  # object is not weak-referenceable: keep the entry until the cache is cleared explicitly
    df_cache[cache_key] = added_axis
    return added_axis
basic_segment_statistics.py 文件源码
python
阅读 36
收藏 0
点赞 0
评论 0
评论列表
文章目录