def variable_with_weight_decay(name, shape, stddev, wd):
    """Create a variable with a truncated-normal initializer and optional decay.

    The variable's dtype is tf.float16 when FLAGS.use_fp16 is set and
    tf.float32 otherwise. The variable itself is created by the
    variable_on_cpu helper (defined elsewhere in this file).

    When `wd` is not None, an op named 'weight_loss' computing
    tf.nn.l2_loss(var) * wd is built and added to the 'losses' graph
    collection. When `wd` is None, nothing is added to the collection.

    Args:
        name: name of the variable.
        shape: list of ints.
        stddev: standard deviation of the truncated Gaussian initializer.
        wd: float multiplier for the L2 loss of the variable, or None to
            skip weight decay for this variable.

    Returns:
        The variable tensor.
    """
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    initializer = tf.truncated_normal_initializer(stddev=stddev, dtype=dtype)
    var = variable_on_cpu(name, shape, initializer)
    if wd is not None:
        # tf.mul was renamed to tf.multiply and removed in TensorFlow 1.0.
        # Prefer the new name; fall back to tf.mul only on installs that
        # predate tf.multiply so older environments keep working.
        multiply = getattr(tf, 'multiply', None) or tf.mul
        weight_decay = multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
        tf.add_to_collection('losses', weight_decay)
    return var
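# [web page residue, not part of the code] Comment list
# [web page residue, not part of the code] Article table of contents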