def extract_test_vals(query, target, query_field, target_field, test_df, is_test_df_sym):
    """Extract the values that have query in the columns and target in the rows.

    Args:
        query (string): value to look for in the ``query_field`` level of the
            columns of ``test_df``
        target (string): value to look for in the ``target_field`` level of
            the index of ``test_df``
        query_field (string): name of multiindex level in which to find query
        target_field (string): name of multiindex level in which to find target
        test_df (pandas multi-index df)
        is_test_df_sym (bool): only matters if query == target; set to True to
            avoid double-counting in the case of a symmetric matrix

    Returns:
        vals (numpy array): 1-D array of the selected values in row-major
            order. When query == target and is_test_df_sym is True, only the
            strictly upper-triangular part (diagonal excluded) of the
            target-rows x query-columns sub-matrix is returned. NaN values
            present in the data are kept in both cases.

    Raises:
        AssertionError: if query / target is not present in the requested
            level (only when Python assertions are enabled, i.e. not -O).
    """
    assert query in test_df.columns.get_level_values(query_field), (
        "query {} is not in the {} level of the columns of test_df.".format(
            query, query_field))
    assert target in test_df.index.get_level_values(target_field), (
        "target {} is not in the {} level of the index of test_df.".format(
            target, target_field))

    # Extract elements where query is in columns and target is in rows
    target_in_rows_query_in_cols_df = test_df.loc[
        test_df.index.get_level_values(target_field) == target,
        test_df.columns.get_level_values(query_field) == query]
    sub_values = target_in_rows_query_in_cols_df.to_numpy()

    # If query == target AND the matrix is symmetric, take only the strict
    # upper triangle (k=1 excludes the diagonal) so that each off-diagonal
    # pair is counted once. NOTE(review): this assumes the selected rows and
    # columns appear in the same order, as in a symmetric matrix whose index
    # and columns are identical.
    if query == target and is_test_df_sym:
        # Index with the boolean mask directly rather than masking to NaN and
        # filtering NaNs afterwards: that approach also threw away genuine
        # NaNs and failed on non-float dtypes. Plain ``bool`` replaces
        # ``np.bool``, which was removed in NumPy 1.24.
        mask = np.triu(np.ones(sub_values.shape, dtype=bool), k=1)
        vals = sub_values[mask]
    else:
        vals = sub_values.flatten()
    return vals
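# Stray web-page residue (not code), kept for reference:
# "评论列表" ("Comment list"), "文章目录" ("Table of contents").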