def weight_variable(shape,
                    initializer=None,
                    init_val=None,
                    wd=None,
                    name=None,
                    trainable=True):
    """Create a weight variable, optionally registering a weight-decay loss.

    Args:
        shape: Shape of the weights, list of int. Only used when `init_val`
            is None (it is passed to `initializer`).
        initializer: Callable invoked as `initializer(shape)` to produce the
            initial value. Defaults to
            `tf.truncated_normal_initializer(stddev=0.01)`. Ignored when
            `init_val` is given.
        init_val: Explicit initial value passed directly to `tf.Variable`.
            When given, `shape` and `initializer` are not used.
        wd: Weight-decay coefficient. When truthy, the tensor
            `tf.nn.l2_loss(var) * wd` (named 'weight_loss') is added to the
            'losses' graph collection. None or 0 disables weight decay.
        name: Optional name for the created variable.
        trainable: Whether the variable is trainable.

    Returns:
        The created `tf.Variable`.
    """
    if init_val is None:
        # Only build the default initializer when it will actually be used.
        if initializer is None:
            initializer = tf.truncated_normal_initializer(stddev=0.01)
        var = tf.Variable(initializer(shape), name=name, trainable=trainable)
    else:
        var = tf.Variable(init_val, name=name, trainable=trainable)

    if wd:
        # `tf.mul` was removed in TensorFlow 1.0 (replaced by `tf.multiply`,
        # available since 0.12); fall back to `tf.mul` on older releases.
        multiply = getattr(tf, 'multiply', None) or tf.mul
        weight_decay = multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
        tf.add_to_collection('losses', weight_decay)
    return var
评论列表
文章目录