special.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:rllabplusplus 作者: shaneshixiang 项目源码 文件源码
def normalize_updates(old_mean, old_std, new_mean, new_std, old_W, old_b):
    """
    Compute the updates for normalizing the last (linear) layer of a neural
    network
    """
    # Make necessary transformation so that
    # (W_old * h + b_old) * std_old + mean_old == \
    #   (W_new * h + b_new) * std_new + mean_new
    new_W = old_W * old_std[0] / (new_std[0] + 1e-6)
    new_b = (old_b * old_std[0] + old_mean[0] - new_mean[0]) / (new_std[0] + 1e-6)
    return OrderedDict([
        (old_W, TT.cast(new_W, old_W.dtype)),
        (old_b, TT.cast(new_b, old_b.dtype)),
        (old_mean, new_mean),
        (old_std, new_std),
    ])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号