def dense(x, num_units, nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):
    ''' Fully connected layer with weight normalization (Salimans & Kingma, 2016).

    Computes ``(x @ V) * g / ||V|| + b``, where ``||V||`` is the column-wise
    L2 norm of ``V``. The result is optionally passed through ``nonlinearity``.

    Args:
        x: input tensor. It must be rank 2, because it is matrix-multiplied by
            the 2-D weight ``V``. Its static dimension 1 sets the input width of
            ``V``. It is presumably (batch, features); confirm with callers.
        num_units: number of output units (columns of ``V``).
        nonlinearity: optional callable applied to the layer output.
        init_scale: on the ``init`` path, the normalized pre-activations are
            scaled so each unit has standard deviation of about ``init_scale``
            on the current batch.
        counters: dict handed to ``get_name`` to build the variable-scope name.
            The ``{}`` default is one dict object shared by every call that
            omits this argument (standard Python default-argument semantics).
        init: if True, create ``V``, ``g`` and ``b``, initializing ``g``/``b``
            from the statistics of this batch. If False, fetch the existing
            variables via ``get_vars_maybe_avg``.
        ema: forwarded to ``get_vars_maybe_avg`` when ``init`` is False. It is
            presumably an ExponentialMovingAverage whose averaged values
            replace the raw variables; confirm in that helper.
        **kwargs: accepted and ignored.

    Returns:
        The layer output. Before ``nonlinearity`` it has shape
        [rows of x, num_units].
    '''
    scope_name = get_name('dense', counters)
    # Shape used to broadcast the per-unit vectors across the rows of the output.
    row_shape = [1, num_units]
    with tf.variable_scope(scope_name):
        if init:
            # Data-dependent initialization. Choose g and b so that this
            # batch's pre-activations have zero mean and a per-unit standard
            # deviation of about init_scale.
            in_dim = int(x.get_shape()[1])
            weights = tf.get_variable('V', [in_dim, num_units], tf.float32,
                                      tf.random_normal_initializer(0, 0.05), trainable=True)
            # g and b are initialized from values that depend on V, so V must
            # be read through initialized_value() rather than used directly.
            unit_weights = tf.nn.l2_normalize(weights.initialized_value(), [0])
            projected = tf.matmul(x, unit_weights)
            batch_mean, batch_var = tf.nn.moments(projected, [0])
            # The small epsilon avoids dividing by zero for a unit with zero variance.
            gain_init = init_scale / tf.sqrt(batch_var + 1e-10)
            tf.get_variable('g', dtype=tf.float32, initializer=gain_init, trainable=True)
            tf.get_variable('b', dtype=tf.float32, initializer=-batch_mean * gain_init, trainable=True)
            # Emit the normalized activations directly from the batch
            # statistics instead of reading back the new g and b.
            out = tf.reshape(gain_init, row_shape) * (projected - tf.reshape(batch_mean, row_shape))
        else:
            # Presumably the variables created by an earlier init=True call
            # (or their averages, depending on ema).
            weights, gain, bias = get_vars_maybe_avg(['V', 'g', 'b'], ema)
            # NOTE(review): the op returned here is neither stored nor used as
            # a control dependency, so it presumably never runs. Either wire it
            # in or drop it.
            tf.assert_variables_initialized([weights, gain, bias])
            projected = tf.matmul(x, weights)
            # Weight normalization: the effective weight is g * V / ||V|| for each column.
            col_norm = tf.sqrt(tf.reduce_sum(tf.square(weights), [0]))
            out = tf.reshape(gain / col_norm, row_shape) * projected + tf.reshape(bias, row_shape)
        # Stay inside the variable scope so any ops created by the nonlinearity are named under it.
        if nonlinearity is not None:
            out = nonlinearity(out)
        return out
评论列表
文章目录