import tensorflow as tf

# NOTE: DummyDict and the _pretrained_resnet_* initializer helpers come from
# the surrounding codebase and are assumed to be in scope here.


def resnet_atrous_conv(x, channels, size=3, padding='SAME', stride=1, hole=1, batch_norm=False,
                       phase_test=None, activation=tf.nn.relu, name=None,
                       parameter_name=None, bn_name=None, scale_name=None, summarize_scale=False,
                       info=DummyDict(), parameters={},
                       pre_adjust_batch_norm=False):
    if parameter_name is None:
        parameter_name = name
    if scale_name is None:
        scale_name = parameter_name
    with tf.name_scope(name):
        features = int(x.get_shape()[3])
        f = channels
        shape = [size, size, features, f]
        # Build initializers from the pretrained ResNet parameters, optionally
        # folding batch-norm statistics into the weights and biases when
        # pre_adjust_batch_norm is set.
        W_init, W_shape = _pretrained_resnet_conv_weights_initializer(parameter_name, parameters,
                                                                      info=info.get('init'),
                                                                      pre_adjust_batch_norm=pre_adjust_batch_norm,
                                                                      bn_name=bn_name, scale_name=scale_name)
        b_init, b_shape = _pretrained_resnet_biases_initializer(scale_name, parameters,
                                                                info=info.get('init'),
                                                                pre_adjust_batch_norm=pre_adjust_batch_norm,
                                                                bn_name=bn_name)

        assert W_shape is None or tuple(W_shape) == tuple(shape), "Incorrect weights shape for {} (file: {}, spec: {})".format(name, W_shape, shape)
        assert b_shape is None or tuple(b_shape) == (f,), "Incorrect bias shape for {} (file: {}, spec: {})".format(name, b_shape, (f,))

        with tf.variable_scope(name):
            W = tf.get_variable('weights', shape, dtype=tf.float32,
                                initializer=W_init)
            b = tf.get_variable('biases', [f], dtype=tf.float32,
                                initializer=b_init)

        if hole == 1:
            # Dense convolution; any stride is applied by subsampling below.
            raw_conv0 = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding=padding)
        else:
            # Atrous (dilated) convolution; striding is not supported together with holes.
            assert stride == 1
            raw_conv0 = tf.nn.atrous_conv2d(x, W, rate=hole, padding=padding)

        #conv0 = tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding=padding)
        if stride > 1:
            # Emulate a strided convolution by subsampling the dense output.
            conv0 = tf.strided_slice(raw_conv0, [0, 0, 0, 0], raw_conv0.get_shape(),
                                     [1, stride, stride, 1])
        else:
            conv0 = raw_conv0

        # bias_add keeps the shape; the reshape just pins down the static shape.
        h1 = tf.reshape(tf.nn.bias_add(conv0, b), conv0.get_shape())
        z = h1
        if activation is not None:
            z = activation(z)

        if info.get('scale_summary'):
            with tf.name_scope('activation'):
                tf.summary.scalar('activation/' + name, tf.sqrt(tf.reduce_mean(z**2)))

        info['activations'][name] = z
        return z
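
A minimal usage sketch (not from the original code) of how this layer builder might be called. The input shape, the layer name, the plain dict standing in for DummyDict, and the assumption that the _pretrained_resnet_* helpers fall back to default initializers when `parameters` is empty are all illustrative:

# Hypothetical example (illustrative names and shapes): add a rate-2 dilated
# 3x3 convolution on top of a 64x64 feature map while keeping its resolution.
x = tf.placeholder(tf.float32, [1, 64, 64, 1024], name='res4_features')
info = {'activations': {}}  # plain dict standing in for DummyDict here
y = resnet_atrous_conv(x, channels=256, size=3, hole=2, stride=1,
                       name='res5a_branch2a', parameters={}, info=info)
# With hole=2 the 3x3 filter covers a 5x5 input window, and with stride=1 and
# SAME padding the output keeps the 64x64 spatial size: [1, 64, 64, 256].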