def conv(inp, name, filter_size, out_channels, stride=1,
         padding='SAME', nonlinearity=None, init_scale=1.0, dilation=None):
    """Weight-normalized 2-D convolution layer with data-dependent initialization.

    The kernel is parameterized as ``W = g * V / ||V||``, where the norm is
    taken per output channel over the (height, width, in_channels) axes of
    ``V``. The layer computes ``conv(inp, W) + b``.

    When ``tf.GLOBAL['init']`` is true, this call creates the variables
    ``V``, ``g`` and ``b`` in variable scope ``name``:

    - ``V`` is drawn from N(0, 0.05).
    - ``g`` and ``b`` are derived from the per-channel mean/variance of the
      normalized convolution applied to ``inp``. The initial pre-activation
      output therefore has roughly zero mean and standard deviation
      ``init_scale`` per channel.
    - ``g`` and ``b`` get an L2 regularizer weighted by ``tf.GLOBAL['reg']``.

    Otherwise the existing ``V``, ``g`` and ``b`` are fetched via
    ``get_variable`` and applied.

    Args:
        inp: input tensor. Its channel count is read from axis 3, so it is
            assumed to be NHWC (batch, height, width, channels).
        name (str): variable scope name.
        filter_size (int or int pair): kernel height and width; a single int
            is used for both.
        out_channels (int): number of output channels.
        stride (int): horizontal and vertical stride. Must be 1 when
            ``dilation`` is given.
        padding (str): padding mode passed to the convolution op
            (e.g. 'SAME' or 'VALID').
        nonlinearity (callable or None): activation applied to the output.
        init_scale (float): target per-channel output scale used by the
            data-dependent initialization of ``g`` and ``b``.
        dilation (int or None): if set, ``tf.nn.atrous_conv2d`` with this
            rate is used instead of a strided ``tf.nn.conv2d``.

    Returns:
        The layer output tensor, after ``nonlinearity`` if one was given.

    Raises:
        ValueError: if ``dilation`` is set and ``stride != 1``.
    """
    if dilation is not None and stride != 1:
        # tf.nn.atrous_conv2d takes no stride argument. Fail explicitly rather
        # than with an `assert`, which is stripped under `python -O` and
        # would let the requested stride be silently ignored.
        raise ValueError(
            'stride must be 1 when dilation is set, got stride=%r' % (stride,))
    if isinstance(filter_size, int):
        filter_size = (filter_size, filter_size)

    strides = [1, stride, stride, 1]

    def apply_conv(kernel):
        """Convolve `inp` with `kernel`, strided or dilated per the arguments."""
        if dilation is None:
            return tf.nn.conv2d(inp, kernel, strides, padding)
        return tf.nn.atrous_conv2d(inp, kernel, dilation, padding)

    with tf.variable_scope(name):
        if tf.GLOBAL['init']:
            in_channels = inp.get_shape().as_list()[3]
            V = get_variable('V', shape=tuple(filter_size) + (in_channels, out_channels),
                             dtype=tf.float32,
                             initializer=tf.random_normal_initializer(0, 0.05),
                             trainable=True)
            # initialized_value() uses V's initial value, so this init pass
            # does not depend on V having been initialized beforehand.
            V_norm = tf.nn.l2_normalize(V.initialized_value(), [0, 1, 2])
            out = apply_conv(V_norm)
            # Per-output-channel statistics (reduced over batch, height, width).
            m_init, v_init = tf.nn.moments(out, [0, 1, 2])
            scale_init = init_scale / tf.sqrt(v_init + 1e-8)
            reg = tf.contrib.layers.l2_regularizer(tf.GLOBAL['reg'])
            g = get_variable('g', shape=None, dtype=tf.float32, initializer=scale_init,
                             trainable=True, regularizer=reg)
            b = get_variable('b', shape=None, dtype=tf.float32,
                             initializer=-m_init * scale_init,
                             trainable=True, regularizer=reg)
            # With b = -m * g this equals conv(inp, g * V_norm) + b, i.e. the
            # output the layer produces with the freshly initialized g and b.
            out = tf.reshape(scale_init, [1, 1, 1, out_channels]) * (
                out - tf.reshape(m_init, [1, 1, 1, out_channels]))
        else:
            V, g, b = get_variable('V'), get_variable('g'), get_variable('b')
            # NOTE(review): the returned op is neither run nor attached as a
            # control dependency, so in graph mode this check presumably never
            # executes. Kept as-is; confirm whether it is meant to be enforced.
            tf.assert_variables_initialized([V, g, b])
            W = g[None, None, None] * tf.nn.l2_normalize(V, [0, 1, 2])
            out = apply_conv(W) + b[None, None, None]
        # Applied inside the variable scope, as before, so any ops/variables
        # the nonlinearity creates keep the same names.
        if nonlinearity is not None:
            out = nonlinearity(out)
    return out
评论列表
文章目录