resnet.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:dlbench 作者: hclhkbu 项目源码 文件源码
def bn(x, c):
    x_shape = x.get_shape() 
    params_shape = x_shape[-1:]

    if c['use_bias']:
        bias = _get_variable('bias', params_shape,
                             initializer=tf.zeros_initializer())
        return x + bias

    batch_norm_config = {'decay': 0.9, 'epsilon': 1e-5, 'scale': True,
                         'center': True}

    x = tf.contrib.layers.batch_norm(x, 
                                     is_training=c['is_training'],
                                     fused=True,
                                     data_format=DATA_FORMAT,
                                     **batch_norm_config)
    return x
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号