import tensorflow as tf


def linear_mapping_weightnorm(inputs, out_dim, in_dim=None, dropout=1.0, var_scope_name="linear_mapping"):
    with tf.variable_scope(var_scope_name):
        input_shape = inputs.get_shape().as_list()  # static shape; entries may be None
        input_shape_tensor = tf.shape(inputs)       # dynamic shape
        # weight normalization (Salimans & Kingma, 2016): w = g * v / ||v||_2
        V = tf.get_variable('V', shape=[int(input_shape[-1]), out_dim], dtype=tf.float32,
                            initializer=tf.random_normal_initializer(
                                mean=0, stddev=tf.sqrt(dropout * 1.0 / int(input_shape[-1]))),
                            trainable=True)
        # V has shape [in_dim, out_dim], so the per-column norm V_norm has shape [out_dim]
        V_norm = tf.norm(V.initialized_value(), axis=0)
        g = tf.get_variable('g', dtype=tf.float32, initializer=V_norm, trainable=True)
        # the weight-norm bias is initialized to zero
        b = tf.get_variable('b', shape=[out_dim], dtype=tf.float32,
                            initializer=tf.zeros_initializer(), trainable=True)

        assert len(input_shape) == 3
        # collapse [batch, time, in_dim] to [batch*time, in_dim] for the matmul, then restore
        inputs = tf.reshape(inputs, [-1, input_shape[-1]])
        inputs = tf.matmul(inputs, V)  # x*v
        inputs = tf.reshape(inputs, [input_shape_tensor[0], -1, out_dim])

        scaler = tf.div(g, tf.norm(V, axis=0))  # g / ||v||_2
        # x*v * g/||v||_2 + b, broadcast over the batch and time axes
        inputs = tf.reshape(scaler, [1, out_dim]) * inputs + tf.reshape(b, [1, out_dim])
        return inputs
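
As a minimal smoke test of the function above (a sketch only: it assumes TF 1.x graph mode, and the shapes, scope name, and session setup here are illustrative, not from the original code):

import numpy as np
import tensorflow as tf

# hypothetical input: 10 timesteps, 256 input channels, batch size left dynamic
x = tf.placeholder(tf.float32, shape=[None, 10, 256])
y = linear_mapping_weightnorm(x, out_dim=512, var_scope_name="fc1")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, feed_dict={x: np.random.randn(4, 10, 256).astype(np.float32)})
    print(out.shape)  # (4, 10, 512)

Because scaler and b are reshaped to [1, out_dim], they broadcast over both the batch and time axes, so the same per-output-channel gain g/||v|| and bias apply at every position.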