nn.py source code

python

Project: image_captioning  Author: bityangke
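import math

import tensorflow as tf  # TensorFlow 1.x API (tf.get_variable, tf.add_to_collection)


def weight(name, shape, init='he', range=1, stddev=0.33, init_val=None):
    """Create a weight variable and register its L2 loss in the 'l2' collection.

    init: 'uniform', 'normal', 'he' or 'xavier'; any other value falls back to
    a truncated normal. A non-None init_val overrides init with a constant.
    Note: the parameter `range` shadows the Python builtin inside this function.
    """
    if init_val is not None:
        initializer = tf.constant_initializer(init_val)
    elif init == 'uniform':
        initializer = tf.random_uniform_initializer(-range, range)
    elif init == 'normal':
        initializer = tf.random_normal_initializer(stddev=stddev)
    elif init == 'he':
        # He initialization: N(0, 2 / fan_in). _get_dims is presumably a
        # helper defined elsewhere in nn.py (not shown here).
        fan_in, _ = _get_dims(shape)
        std = math.sqrt(2.0 / fan_in)
        initializer = tf.random_normal_initializer(stddev=std)
    elif init == 'xavier':
        # Glorot/Xavier uniform: U(-r, r) with r = sqrt(6 / (fan_in + fan_out))
        fan_in, fan_out = _get_dims(shape)
        range = math.sqrt(6.0 / (fan_in + fan_out))
        initializer = tf.random_uniform_initializer(-range, range)
    else:
        initializer = tf.truncated_normal_initializer(stddev=stddev)

    var = tf.get_variable(name, shape, initializer=initializer)
    # Record this weight's L2 penalty; summing the 'l2' collection gives the regularizer.
    tf.add_to_collection('l2', tf.nn.l2_loss(var))
    return var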
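The 'he' and 'xavier' branches derive their scale from a layer's fan-in and fan-out via _get_dims, whose body is not shown. The sketch below reproduces only that arithmetic without TensorFlow. get_dims, he_stddev and xavier_limit are hypothetical names, and get_dims assumes TensorFlow weight layouts ([fan_in, fan_out] for dense layers, [height, width, in_channels, out_channels] for conv kernels), which may differ from the real helper.

```python
import math


def get_dims(shape):
    """Hypothetical stand-in for _get_dims (assumed TensorFlow layouts)."""
    if len(shape) == 2:
        # Dense layer: [fan_in, fan_out]
        fan_in, fan_out = shape
    elif len(shape) == 4:
        # Conv kernel: [height, width, in_channels, out_channels]
        receptive_field = shape[0] * shape[1]
        fan_in = receptive_field * shape[2]
        fan_out = receptive_field * shape[3]
    else:
        raise ValueError("unsupported shape: %r" % (shape,))
    return fan_in, fan_out


def he_stddev(shape):
    # Same formula as the 'he' branch: stddev = sqrt(2 / fan_in)
    fan_in, _ = get_dims(shape)
    return math.sqrt(2.0 / fan_in)


def xavier_limit(shape):
    # Same formula as the 'xavier' branch: r = sqrt(6 / (fan_in + fan_out))
    fan_in, fan_out = get_dims(shape)
    return math.sqrt(6.0 / (fan_in + fan_out))


if __name__ == "__main__":
    print(he_stddev([512, 256]))      # 0.0625, i.e. sqrt(2 / 512)
    print(xavier_limit([512, 256]))   # sqrt(6 / 768)
    print(get_dims([3, 3, 64, 128]))  # (576, 1152)
```

The Xavier limit is chosen so that the variance of U(-r, r), which is r^2 / 3, equals 2 / (fan_in + fan_out); the He stddev instead depends on fan_in alone.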