netools.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:nature_methods_multicut_pipeline 作者: ilastik 项目源码 文件源码
def he(shape, gain=1., dtype=th.config.floatX):
    if len(shape) == 4:
        # 2D convolutions
        fmapsin = shape[1]
        fov = np.prod(shape[2:])
    elif len(shape) == 5:
        # 3D convolutions
        fmapsin = shape[2]
        fov = np.prod(shape[3:]) * shape[1]
    else:
        raise NotImplementedError

    # Parse gain
    if isinstance(gain, str):
        if gain.lower() == 'relu':
            gain = 2.
        elif gain.lower() in ['sigmoid', 'linear', 'tanh']:
            gain = 1.

    # Compute variance for He init
    var = gain/(fmapsin * fov)
    # Build kernel
    ker = np.random.normal(loc=0., scale=np.sqrt(var), size=tuple(shape)).astype(dtype)
    return ker


# Training Monitors
# Batch Number
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号