def _expected_fid(real_imgs, gen_imgs):
    """Compute a reference Frechet distance between two sample sets with NumPy.

    Each input is summarized by its mean over axis 0 and its covariance
    (``np.cov`` with ``rowvar=False``, i.e. rows are observations and columns
    are variables). The returned value is::

        ||m - m_v||^2 + Tr(sigma + sigma_v - 2 * sqrt(sigma @ sigma_v))

    Args:
        real_imgs: Array-like of samples for the first distribution.
            NOTE(review): presumably shaped (num_samples, num_features) and
            holding feature activations rather than raw pixels -- confirm
            against the caller.
        gen_imgs: Array-like of samples for the second distribution, with the
            same number of columns as ``real_imgs``.

    Returns:
        The real-valued scalar distance.
    """
    m = np.mean(real_imgs, axis=0)
    m_v = np.mean(gen_imgs, axis=0)
    sigma = np.cov(real_imgs, rowvar=False)
    sigma_v = np.cov(gen_imgs, rowvar=False)

    sqcc = scp_linalg.sqrtm(np.dot(sigma, sigma_v))
    # sqrtm may hand back a complex-typed matrix whose imaginary parts are
    # pure round-off noise; the product of two covariance matrices has a real
    # square-root trace, so keep only the real part to avoid a complex score.
    if np.iscomplexobj(sqcc):
        sqcc = sqcc.real

    mean = np.square(m - m_v).sum()
    trace = np.trace(sigma + sigma_v - 2 * sqcc)
    fid = mean + trace
    return fid
评论列表
文章目录