def gelman_rubin_diagnostic(x, logger, mu=None):
    """Compute the Gelman-Rubin potential scale reduction factor (R-hat).

    Args:
        x: Sample array laid out as ``(m, n, ...)``: axis 0 indexes the ``m``
            chains and axis 1 the ``n`` draws within each chain. Any trailing
            axes are treated as independent parameters, and each gets its own
            R-hat.
        logger: Logger that receives the max / min R-hat at INFO level.
        mu: Optional reference value that the chain means are compared
            against. It may be a scalar or an array that broadcasts against
            the per-chain means (shape ``x.shape[2:]``). When ``None``, the
            grand mean of the chain means is used. ``0`` is a valid value.

    Returns:
        R-hat with shape ``x.shape[2:]`` (a NumPy scalar for 2-D input).

    Raises:
        ZeroDivisionError: if there is only one chain (``m == 1``).

    Note:
        The within-chain variances use ``np.var`` with its default ``ddof=0``
        (population variance). The textbook estimator uses ``ddof=1``; the two
        converge for long chains.
    """
    m, n = x.shape[0], x.shape[1]
    theta = np.mean(x, axis=1)  # per-chain means, shape (m, ...)
    sigma = np.var(x, axis=1)   # per-chain variances, shape (m, ...)
    # Explicit None check: a plain truthiness test treated mu=0 as "not given"
    # and raised for array-valued mu.
    theta_m = np.mean(theta, axis=0) if mu is None else mu
    # Between-chain variance B. Reduce over the chain axis only, so each
    # parameter (trailing axis) keeps its own B, matching how W is reduced.
    b = float(n) / float(m - 1) * np.sum((theta - theta_m) ** 2, axis=0)
    # Within-chain variance W: the average of the per-chain variances.
    w = 1. / float(m) * np.sum(sigma, axis=0)
    # Pooled variance estimate combining W and B.
    v = float(n - 1) / float(n) * w + float(m + 1) / float(m * n) * b
    r_hat = np.sqrt(v / w)
    logger.info('R: max [%f] min [%f]', np.max(r_hat), np.min(r_hat))
    return r_hat
评论列表
文章目录