def calc_tvd(label_dict, attr):
    """Return the total variation distance between real and sampled labels.

    Both distributions are empirical distributions over joint configurations
    of the label columns. The real distribution comes from the rows of
    ``attr[label_names]``. The sampled distribution comes from the rows formed
    by rounding the arrays in ``label_dict`` element-wise.

    Parameters
    ----------
    label_dict : dict
        Mapping ``label_name -> array of samples``. Each array is flattened
        and rounded to the nearest integer; all arrays must have the same
        number of elements.
    attr : pandas.DataFrame
        Data expected to be 0/1, with one column per label name (columns not
        named in ``label_dict`` are ignored). For example::

            names = list(zip(*self.graph))[0]
            calc_tvd(label_dict, attr[names])

    Returns
    -------
    float
        ``0.5 * sum_x |p_real(x) - p_sample(x)|`` over all configurations
        ``x``, including sampled configurations absent from the real data.

    Raises
    ------
    ValueError
        If the selected columns of ``attr`` contain a negative value, or if
        ``label_dict`` provides no samples.
    KeyError
        If a key of ``label_dict`` is not a column of ``attr``.
    """
    # A real list: merge() treats a dict_keys object as one single key.
    label_names = list(label_dict.keys())
    attr = attr[label_names]
    if np.min(attr.values) < 0:
        raise ValueError('calc_tvd received '
                         'attr that may not have been in {0,1}')

    # Give each distinct joint configuration seen in the real data an integer
    # ID. A private column name avoids clashing with label names.
    configs = attr.drop_duplicates().reset_index(drop=True)
    configs = configs.assign(_config_id=np.arange(len(configs)))

    real_ids = attr.merge(configs, on=label_names, how='left')['_config_id']
    real_pdf = real_ids.value_counts() / len(attr)

    samples = pd.DataFrame(
        {k: np.round(np.ravel(v)) for k, v in label_dict.items()})
    n_samples = len(samples)
    if n_samples == 0:
        raise ValueError('calc_tvd received label_dict with no samples')

    sample_ids = samples.merge(configs, on=label_names,
                               how='left')['_config_id']
    # Samples whose configuration never occurs in the real data get a NaN ID.
    # They still belong to the sampled distribution (where p_real is 0), so
    # their mass is kept separately rather than dropped and renormalized away.
    matched_ids = sample_ids.dropna().astype(np.int64)
    dat_pdf = matched_ids.value_counts() / n_samples
    unmatched_mass = (n_samples - len(matched_ids)) / n_samples

    diff = real_pdf.subtract(dat_pdf, fill_value=0)
    return 0.5 * (diff.abs().sum() + unmatched_mass)
# 评论列表 (web-page residue, not code: "comment list")
# 文章目录 (web-page residue, not code: "table of contents")