raputil.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:onsager_deep_learning 作者: mborgerding 项目源码 文件源码
def nlfunc(r,sc,grid,gg,return_gradient=True):
    'returns xhat_nl = rhat_nl * interp( rhat_nl / sc,grid,gg) and optionally the gradient of xhat_nl wrt rhat_nl'
    g = r * np.interp(r/sc,grid,gg)
    if return_gradient:
        #I had some code that computed the gradient, but it was far more complicated and no faster than just computing the empirical gradient
        # technically, this computes a subgradient
        dr = sc * (grid[1]-grid[0]) * 1e-3
        dgdr = (nlfunc(r+.5*dr,sc,grid,gg,False) - nlfunc(r-.5*dr,sc,grid,gg,False)) / dr
        return (g,dgdr)
    else:
        return g
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号