test_gan_metrics.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def _expected_fid(real_imgs, gen_imgs):
    m = np.mean(real_imgs, axis=0)
    m_v = np.mean(gen_imgs, axis=0)
    sigma = np.cov(real_imgs, rowvar=False)
    sigma_v = np.cov(gen_imgs, rowvar=False)
    sqcc = scp_linalg.sqrtm(np.dot(sigma, sigma_v))
    mean = np.square(m - m_v).sum()
    trace = np.trace(sigma + sigma_v - 2 * sqcc)
    fid = mean + trace
    return fid
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号