def shrink_bgest(r, rvar, theta):
    """Bernoulli-Gaussian MMSE shrinkage function.

    Computes the posterior mean E[x|r] under the model

        x   ~ BernoulliGaussian(lambda, xvar1)
        r|x ~ Normal(x, rvar)

    together with the derivative of that estimate with respect to r,
    averaged along axis 0.

    Args:
        r: Noisy observation(s) of x.
        rvar: Variance of the Gaussian noise on r. It is combined
            elementwise with r, so it must broadcast against it
            (exact shapes are not visible here -- TODO confirm with caller).
        theta: Prior parameters, indexed on the last axis:
            theta[..., 0] -- its absolute value is xvar1, the variance of
                the non-zero entries of x.
            theta[..., 1] -- log(1/lambda - 1), where lambda is the
                probability that an entry of x is non-zero, i.e.
                lambda = 1 / (exp(theta[..., 1]) + 1).

    Returns:
        A tuple (xhat, dxdr): xhat is the MMSE estimate of x, and dxdr is
        d(xhat)/dr reduced with a mean over axis 0.

    Note:
        Relies on a module-level name `tf` (presumably TensorFlow, imported
        outside this block). There is no guard against rvar or xvar1 being
        zero.
    """
    nonzero_var = abs(theta[..., 0])
    log_prior_odds = theta[..., 1]  # log(1/lambda - 1)

    # Gain of the linear MMSE estimator for the "x is non-zero" case:
    # xvar1 / (xvar1 + rvar), written so that only ratios appear.
    gain = 1 / (1 + rvar / nonzero_var)
    scaled_r2 = r * r * gain / rvar

    # Posterior odds that x is zero rather than non-zero, given r.
    evidence = tf.exp(log_prior_odds - .5 * scaled_r2)
    odds_zero = evidence * tf.sqrt(1 + nonzero_var / rvar)
    odds_plus_one = odds_zero + 1

    # Linear estimate weighted by the posterior probability 1 / (1 + odds)
    # that x is non-zero.
    xhat = gain * r / odds_plus_one

    # Closed-form derivative of xhat with respect to r.
    slope = gain * (
        (1 + odds_zero * (1 + scaled_r2)) / tf.square(odds_plus_one)
    )
    return (xhat, tf.reduce_mean(slope, 0))
评论列表
文章目录