def lindisc(X, p, t):
    """Linear MMD between the treated and control rows of ``X``.

    Rows are split by ``t``: positions where ``t > 0`` form the treated
    group and positions where ``t < 1`` form the control group. Writing
    ``d`` for the sum of squares of
    ``p * mean_treated - (1 - p) * mean_control``, the returned value is
    ``|p - 0.5| + safe_sqrt((p - 0.5)**2 + d)``.

    Args:
        X: Tensor whose first axis is indexed by the positions selected
            from ``t`` (presumably an ``(n, d)`` representation matrix --
            TODO confirm against caller).
        p: Weight applied to the treated mean; ``1 - p`` weights the
            control mean (presumably the treatment probability -- TODO
            confirm against caller).
        t: Treatment indicator tensor; only the index along its first
            axis is used to pick rows of ``X``.

    Returns:
        The discrepancy tensor described above.
    """
    # tf.where yields one index row per matching element; column 0 is the
    # position along the first axis, which is what selects rows of X.
    treated_idx = tf.where(t > 0)[:, 0]
    control_idx = tf.where(t < 1)[:, 0]

    control_rows = tf.gather(X, control_idx)
    treated_rows = tf.gather(X, treated_idx)

    # The reduction axis is passed positionally on purpose: the keyword was
    # `reduction_indices` in old TensorFlow and is `axis` in current
    # releases, while the positional slot is the same in both.
    mean_control = tf.reduce_mean(control_rows, 0)
    mean_treated = tf.reduce_mean(treated_rows, 0)

    # c == (p - 0.5)**2 and f * (p - 0.5) == |p - 0.5|.
    c = tf.square(2 * p - 1) * 0.25
    f = tf.sign(p - 0.5)

    sq_dist = tf.reduce_sum(
        tf.square(p * mean_treated - (1 - p) * mean_control)
    )
    return f * (p - 0.5) + safe_sqrt(c + sq_dist)
# (Web-page navigation residue, not part of the code: "Comment list" / "Table of contents".)