def add_param(self, spec, shape, name, **kwargs):
    """Create a parameter, optionally wrapping it in weight normalization.

    The raw variable is created via ``add_param_plain``. If weight
    normalization is enabled and the (heuristically detected) parameter is a
    weight matrix/filter, the returned tensor is the reparameterization
    ``w = g * v / ||v||`` (Salimans & Kingma, 2016), where the norm is taken
    over every axis except the last (the output-channel axis) and ``g`` is a
    learned per-output-channel scale initialized to ones.

    Args:
        spec: Initializer/spec forwarded to ``add_param_plain``.
        shape: Shape of the parameter; ``shape[-1]`` is treated as the
            output-channel dimension when weight normalization applies.
        name: Parameter name; weight normalization is applied only when the
            name starts with ``"W"`` (hacky weight-matrix detection).
        **kwargs: Extra arguments forwarded to ``add_param_plain``.

    Returns:
        The parameter tensor, weight-normalized when applicable.

    Raises:
        NotImplementedError: If weight normalization is requested for a
            parameter of rank < 2.
    """
    param = self.add_param_plain(spec, shape, name, **kwargs)
    if name is not None and name.startswith("W") and self.weight_normalization:
        # Hacky: any "W*"-named parameter is assumed to be a weight.
        ndim = len(param.get_shape())
        if ndim < 2:
            raise NotImplementedError
        v = param
        # Per-output-channel gain g, initialized to 1 (so w == v / ||v|| at init).
        g = self.add_param_plain(tf.ones_initializer, (shape[-1],),
                                 name=name + "_wn/g")
        # Normalize over all axes except the last; broadcast g across them.
        # Generalizes the original 2-D (dense) and 4-D (conv) special cases
        # to any rank >= 2.
        norm_axes = list(range(ndim - 1))
        broadcast_shape = [1] * (ndim - 1) + [-1]
        param = v * (tf.reshape(g, broadcast_shape) /
                     tf.sqrt(tf.reduce_sum(tf.square(v), norm_axes,
                                           keep_dims=True)))
    return param
# NOTE(review): removed scraped-webpage residue that was pasted here
# ("评论列表" = "comment list", "文章目录" = "article table of contents" —
# blog-page navigation text); as bare text it was not valid Python.