sip.py source code

python

Project: psp  Author: cmap
import numpy as np


def extract_test_vals(query, target, query_field, target_field, test_df, is_test_df_sym):
    """ Extract values that have query in the columns and target in the rows.

    Args:
        query (string)
        target (string)
        query_field (string): name of multiindex level in which to find query
        target_field (string): name of multiindex level in which to find target
        test_df (pandas multi-index df)
        is_test_df_sym (bool): only matters if query == target; set to True to
            avoid double-counting in the case of a symmetric matrix

    Returns:
        vals (numpy array)

    """
    assert query in test_df.columns.get_level_values(query_field), (
        "query {} is not in the {} level of the columns of test_df.".format(
            query, query_field))

    assert target in test_df.index.get_level_values(target_field), (
        "target {} is not in the {} level of the index of test_df.".format(
            target, target_field))

    # Extract elements where query is in columns and target is in rows
    target_in_rows_query_in_cols_df = test_df.loc[
            test_df.index.get_level_values(target_field) == target,
            test_df.columns.get_level_values(query_field) == query]

    # If query == target AND the matrix is symmetric, keep only the strict
    # upper triangle (k=1 excludes the diagonal) of the extracted block so
    # that each pair is counted once
    if query == target and is_test_df_sym:
        # np.bool was removed in NumPy 1.24; the builtin bool is equivalent
        mask = np.triu(np.ones(target_in_rows_query_in_cols_df.shape), k=1).astype(bool)
        vals_with_nans = target_in_rows_query_in_cols_df.where(mask).values.flatten()
        vals = vals_with_nans[~np.isnan(vals_with_nans)]

    else:
        vals = target_in_rows_query_in_cols_df.values.flatten()

    return vals
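The symmetric branch is easiest to see on a tiny matrix. The sketch below does not call extract_test_vals. Instead it repeats the same level-based selection and triu masking on a made-up 3x3 symmetric matrix whose rows and columns share one MultiIndex. The level names `group` and `sample` and all values are hypothetical, chosen only for illustration.

```python
import numpy as np
import pandas as pd

# Hypothetical symmetric similarity matrix: three samples, all in group "A".
# Level names "group" and "sample" are illustrative assumptions.
idx = pd.MultiIndex.from_tuples(
    [("A", "s1"), ("A", "s2"), ("A", "s3")], names=["group", "sample"])
sim = np.array([[1.0, 0.2, 0.5],
                [0.2, 1.0, 0.7],
                [0.5, 0.7, 1.0]])
df = pd.DataFrame(sim, index=idx, columns=idx)

# Same boolean selection on a MultiIndex level as in extract_test_vals
sub = df.loc[df.index.get_level_values("group") == "A",
             df.columns.get_level_values("group") == "A"]

# Naive flatten: each off-diagonal pair appears twice, plus the diagonal
naive = sub.values.flatten()

# Strict upper triangle (k=1): each unordered pair once, diagonal dropped
mask = np.triu(np.ones(sub.shape), k=1).astype(bool)
dedup = sub.where(mask).values.flatten()
dedup = dedup[~np.isnan(dedup)]

print(len(naive), len(dedup))  # 9 3
print(dedup)                   # [0.2 0.5 0.7]
```

One side effect of this masking approach: because masked cells become NaN and are then filtered with `~np.isnan`, any NaN that was already present in the upper triangle of the original data is dropped as well.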