def dense(x, num_units, nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):
    """Weight-normalized fully connected layer (Salimans & Kingma, 2016).

    Args:
        x: 2-D input tensor, assumed shape [batch, in_features] — the reshapes
            below require exactly rank 2.
        num_units: number of output units.
        nonlinearity: optional activation applied to the output.
        init_scale: target stddev of the pre-activations during the
            data-dependent initialization pass.
        counters: dict shared across layer constructors; get_name uses it to
            produce unique variable-scope names. NOTE: the mutable default is
            intentional here — the shared dict is what keeps scope names
            distinct across repeated calls that omit this argument.
        init: if True, build the data-dependent initialization graph instead
            of the normal forward pass.
        ema: optional ExponentialMovingAverage; when given, the averaged
            copies of V/g/b are used (via get_vars_maybe_avg).
        **kwargs: ignored; accepted so all layer constructors share a
            uniform call signature.

    Returns:
        Output tensor of shape [batch, num_units].
    """
    name = get_name('dense', counters)
    with tf.variable_scope(name):
        if init:
            # Data-dependent initialization: choose g and b so that the
            # initial pre-activations have zero mean and stddev init_scale
            # on the first (initialization) minibatch.
            V = tf.get_variable('V', [int(x.get_shape()[1]), num_units], tf.float32,
                                tf.random_normal_initializer(0, 0.05), trainable=True)
            # Normalize each column of V; initialized_value() lets us use V
            # before the global variable initializer has run.
            V_norm = tf.nn.l2_normalize(V.initialized_value(), [0])
            x_init = tf.matmul(x, V_norm)
            m_init, v_init = tf.nn.moments(x_init, [0])
            # Epsilon guards against zero variance on the init batch.
            scale_init = init_scale / tf.sqrt(v_init + 1e-10)
            g = tf.get_variable('g', dtype=tf.float32,
                                initializer=scale_init, trainable=True)
            b = tf.get_variable('b', dtype=tf.float32,
                                initializer=-m_init * scale_init, trainable=True)
            x_init = tf.reshape(scale_init, [1, num_units]) * (
                x_init - tf.reshape(m_init, [1, num_units]))
            if nonlinearity is not None:
                x_init = nonlinearity(x_init)
            return x_init
        else:
            V, g, b = get_vars_maybe_avg(['V', 'g', 'b'], ema)
            # BUG FIX: the op returned by assert_variables_initialized was
            # previously discarded, so the check never actually executed in
            # graph mode. Attaching it as a control dependency makes it run
            # together with the forward pass.
            with tf.control_dependencies([tf.assert_variables_initialized([V, g, b])]):
                # Weight normalization: effective weights are g * V / ||V||,
                # applied as a per-output-unit rescale after the matmul.
                x = tf.matmul(x, V)
                scaler = g / tf.sqrt(tf.reduce_sum(tf.square(V), [0]))
                x = tf.reshape(scaler, [1, num_units]) * x + tf.reshape(b, [1, num_units])
            # apply nonlinearity
            if nonlinearity is not None:
                x = nonlinearity(x)
            return x
# NOTE(review): the two lines below are page-navigation text scraped along with
# the code ("评论列表" = comment list, "文章目录" = article table of contents).
# They are not code and would raise NameError at import; commented out.
# 评论列表
# 文章目录