pairwise.py source code

python

Project: CausalGAN  Author: mkocaoglu
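import numpy as np
import pandas as pd


def calc_tvd(label_dict,attr):
    '''
    Total variation distance between the empirical joint distribution
    of the real labels (attr) and that of the sampled labels (label_dict).

    attr should be a 0,1 pandas dataframe with
    columns corresponding to label names

    for example:
    names=list(zip(*self.graph))[0]
    calc_tvd(label_dict,attr[names])

    label_dict should be a dictionary key:1d-array of samples
    '''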
    ####Calculate Total Variation####
    # Only the lower bound is checked; values above 1 are not caught here.
    if np.min(attr.values)<0:
        raise ValueError('calc_tvd received '
                         'attr that may not have been in {0,1}')

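    # Materialize the keys as a list: a dict_keys view (Python 3) is not
    # accepted as a column selector by pandas.
    label_names=list(label_dict.keys())
    attr=attr[label_names]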

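    # Give every distinct label combination in the real data an integer ID.
    df2=attr.drop_duplicates()
    df2 = df2.reset_index(drop = True).reset_index()
    df2=df2.rename(columns = {'index':'ID'})
    # Merge on the shared label columns to tag each real row with its ID.
    real_data_id=pd.merge(attr,df2)
    real_counts = real_data_id['ID'].value_counts()
    real_pdf=real_counts/len(attr)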

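    # Round the sampled labels to 0/1 and look up the ID of each combination.
    label_list_dict={k:np.round(v.ravel()) for k,v in label_dict.items()}
    df_dat=pd.DataFrame.from_dict(label_list_dict)
    dat_id=pd.merge(df_dat,df2,on=list(label_names),how='left')
    # Samples whose combination never occurs in the real data get ID NaN;
    # value_counts() drops them, so dat_pdf is normalized over matched samples only.
    dat_counts=dat_id['ID'].value_counts()
    dat_pdf = dat_counts / dat_counts.sum()
    # IDs missing on one side count as probability 0.
    diff=real_pdf.subtract(dat_pdf, fill_value=0)
    tvd=0.5*diff.abs().sum()
    return tvd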
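
To make the quantity concrete, here is a minimal standalone sketch of the same idea: half the sum of absolute differences between two empirical joint distributions over 0/1 label combinations. It is not the author's code. `joint_tvd`, `real` and `fake` are hypothetical names introduced for illustration, and the sketch uses `groupby(...).size()` instead of the merge-based ID lookup above. One deliberate difference is that combinations appearing only in the samples are kept here, whereas the left merge in `calc_tvd` drops them.

```python
import pandas as pd


def joint_tvd(real, fake):
    """TVD between the empirical joint distributions of two 0/1 DataFrames.

    Hypothetical helper for illustration; not part of CausalGAN.
    """
    cols = list(real.columns)
    # Relative frequency of each distinct row (label combination).
    p = real.groupby(cols).size() / len(real)
    q = fake.groupby(cols).size() / len(fake)
    # Align on the combinations; one missing on either side counts as 0.
    return 0.5 * p.subtract(q, fill_value=0).abs().sum()


# Toy data: the real labels are uniform over the four combinations,
# the sampled labels put half their mass on (0, 0) and none on (1, 0).
real = pd.DataFrame({"a": [0, 0, 1, 1], "b": [0, 1, 0, 1]})
fake = pd.DataFrame({"a": [0, 0, 0, 1], "b": [0, 0, 1, 1]})

tvd = joint_tvd(real, fake)
print(tvd)  # 0.25
```

In the toy data the two distributions differ by 0.25 on (0, 0) and by 0.25 on (1, 0), so the distance is 0.5 * (0.25 + 0.25) = 0.25. Identical inputs give 0.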