resnet152_bn.py 文件源码

python
阅读 87 收藏 0 点赞 0 评论 0

项目:self-supervision 作者: gustavla 项目源码 文件源码
def resnet_atrous_conv(x, channels, size=3, padding='SAME', stride=1, hole=1, batch_norm=False,
         phase_test=None, activation=tf.nn.relu, name=None,
         parameter_name=None, bn_name=None, scale_name=None, summarize_scale=False, info=DummyDict(), parameters={},
         pre_adjust_batch_norm=False):
    """Build a (possibly dilated) 2-D convolution with bias and optional activation.

    The kernel and bias are created with ``tf.get_variable`` inside a variable
    scope called ``name`` and are initialized by the pretrained-ResNet
    initializer helpers defined elsewhere in this module.

    Args:
        x: Input tensor. The size of its dimension at index 3 must be
            statically known; it is used as the number of input features
            (presumably NHWC layout -- TODO confirm against callers).
        channels: Number of output channels.
        size: Side length of the square kernel (kernel is ``size x size``).
        padding: Padding mode forwarded to the convolution op.
        stride: Spatial subsampling factor. The convolution itself always runs
            at stride 1; for ``stride > 1`` every ``stride``-th row and column
            of that output is kept. Must be 1 when ``hole != 1``.
        hole: Dilation rate. ``1`` uses ``tf.nn.conv2d``; any other value uses
            ``tf.nn.atrous_conv2d`` with ``rate=hole``.
        batch_norm: Accepted for signature compatibility; not used here.
        phase_test: Accepted for signature compatibility; not used here.
        activation: Callable applied to the biased convolution output, or
            ``None`` to return the linear output.
        name: Name used for the name scope, the variable scope, the summary
            tag and the key under ``info['activations']``. Should be a string
            even though the default is ``None``.
        parameter_name: Name handed to the weights initializer helper.
            Defaults to ``name``.
        bn_name: Forwarded to both initializer helpers.
        scale_name: Name handed to the bias initializer helper and forwarded
            as ``scale_name`` to the weights initializer helper. Defaults to
            ``parameter_name``.
        summarize_scale: Accepted for signature compatibility; not used here.
        info: Dict-like object. ``info.get('init')`` is forwarded to the
            initializer helpers, a truthy ``info.get('scale_summary')``
            enables a scalar summary of the root-mean-square activation, and
            the output tensor is stored in ``info['activations'][name]``.
        parameters: Forwarded unchanged to both initializer helpers
            (presumably the pretrained parameter values -- TODO confirm).
        pre_adjust_batch_norm: Forwarded to both initializer helpers.

    Returns:
        The output tensor after bias addition and, if given, ``activation``.

    Raises:
        AssertionError: If an initializer helper reports a weight or bias
            shape that differs from the shape implied by the arguments, or if
            ``stride != 1`` is combined with ``hole != 1``.

    NOTE(review): the ``info`` and ``parameters`` defaults are evaluated once
    at definition time and are therefore shared between calls that rely on
    them. They are kept as-is to preserve the existing signature.
    """
    if parameter_name is None:
        parameter_name = name
    if scale_name is None:
        scale_name = parameter_name
    with tf.name_scope(name):
        features = int(x.get_shape()[3])
        f = channels
        shape = [size, size, features, f]

        W_init, W_shape = _pretrained_resnet_conv_weights_initializer(parameter_name, parameters,
                                                          info=info.get('init'),
                                                          pre_adjust_batch_norm=pre_adjust_batch_norm,
                                                          bn_name=bn_name, scale_name=scale_name)
        b_init, b_shape = _pretrained_resnet_biases_initializer(scale_name, parameters,
                                                    info=info.get('init'),
                                                    pre_adjust_batch_norm=pre_adjust_batch_norm,
                                                    bn_name=bn_name)

        # A shape of None from a helper means "nothing to check against".
        assert W_shape is None or tuple(W_shape) == tuple(shape), "Incorrect weights shape for {} (file: {}, spec: {})".format(name, W_shape, shape)
        assert b_shape is None or tuple(b_shape) == (f,), "Incorrect bias shape for {} (file: {}, spec: {})".format(name, b_shape, (f,))

        with tf.variable_scope(name):
            W = tf.get_variable('weights', shape, dtype=tf.float32,
                                initializer=W_init)
            b = tf.get_variable('biases', [f], dtype=tf.float32,
                                initializer=b_init)

        if hole == 1:
            raw_conv0 = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding=padding)
        else:
            assert stride == 1
            raw_conv0 = tf.nn.atrous_conv2d(x, W, rate=hole, padding=padding)

        # Striding is applied by subsampling the stride-1 output instead of
        # passing `strides` to conv2d (presumably so the sampled positions
        # match the pretrained model -- TODO confirm).
        if stride > 1:
            conv0 = tf.strided_slice(raw_conv0, [0, 0, 0, 0], raw_conv0.get_shape(), [1, stride, stride, 1])
        else:
            conv0 = raw_conv0

        # Reshape to the static shape of the convolution output after adding
        # the bias.
        z = tf.reshape(tf.nn.bias_add(conv0, b), conv0.get_shape())

    if activation is not None:
        z = activation(z)

    if info.get('scale_summary'):
        with tf.name_scope('activation'):
            # Root-mean-square of the activations as a single scalar.
            tf.summary.scalar('activation/' + name, tf.sqrt(tf.reduce_mean(z**2)))

    info['activations'][name] = z
    return z
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号