keras_conversion.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:deeplift 作者: kundajelab 项目源码 文件源码
def batchnorm_conversion(layer, name, verbose, **kwargs):
    import keras
    if (hasattr(keras,'__version__')):
        keras_version = float(keras.__version__[0:3])
    else:
        keras_version = 0.2
    if (keras_version <= 0.3):
        std = np.array(layer.running_std.get_value())
        epsilon = layer.epsilon
    else:
        std = np.sqrt(np.array(layer.running_std.get_value()+layer.epsilon))
        epsilon = 0
    return [blobs.BatchNormalization(
            name=name,
            verbose=verbose,
            gamma=np.array(layer.gamma.get_value()),
            beta=np.array(layer.beta.get_value()),
            axis=layer.axis,
            mean=np.array(layer.running_mean.get_value()),
            std=std,
            epsilon=epsilon)]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号