import tensorflow as tf  # TensorFlow 1.x graph-mode API

def _get_weight_variable(self, layer_name, name, shape, L2=1):
    wname = '%s/%s:0' % (layer_name, name)
    # Fan-in/fan-out for Xavier initialization: the last two dimensions are
    # (input channels, output channels); any leading dimensions (e.g. conv
    # kernel height/width) scale both.
    fanin, fanout = shape[-2:]
    for dim in shape[:-2]:
        fanin *= float(dim)
        fanout *= float(dim)
    sigma = self._xavi_norm(fanin, fanout)
    if self.weights is None or wname not in self.weights:
        # No saved value for this weight: draw from a truncated normal.
        w1 = tf.get_variable(name, initializer=tf.truncated_normal(
            shape=shape, mean=0, stddev=sigma))
        print('{:>23} {:>23}'.format(wname, 'randomly initialize'))
    else:
        # A saved value exists: restore it and mark it as loaded.
        w1 = tf.get_variable(name, shape=shape,
                             initializer=tf.constant_initializer(
                                 value=self.weights[wname], dtype=tf.float32))
        self.loaded_weights[wname] = 1
    if wname != w1.name:
        # The variable was created under an unexpected scope; fail loudly.
        print(wname, w1.name)
        assert False
    # Register the scaled L2 penalty so it is summed into the total loss
    # together with the other regularization terms.
    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                         tf.nn.l2_loss(w1) * L2)
    return w1
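The `_xavi_norm` helper is not shown in this snippet; assuming it returns the standard deviation of a Glorot/Xavier normal initializer computed from the fan-in and fan-out, a minimal sketch would be:

import math

def _xavi_norm(self, fanin, fanout):
    # Glorot/Xavier normal: stddev = sqrt(2 / (fan_in + fan_out)),
    # which keeps activation variance roughly constant across layers.
    return math.sqrt(2.0 / (fanin + fanout))

At training time the penalties accumulated in REGULARIZATION_LOSSES are typically summed into the objective; a sketch, where `data_loss` is a hypothetical task loss tensor:

reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
total_loss = data_loss + tf.add_n(reg_losses)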