def weight(name, shape, init='he', range=1, stddev=0.33, init_val=None):
    """Create a variable via ``tf.get_variable`` and record its L2 loss.

    The initializer is chosen as follows:

    * ``init_val`` not None -> ``tf.constant_initializer(init_val)``; this
      takes precedence over ``init``.
    * ``'uniform'`` -> uniform initializer over ``[-range, range]``.
    * ``'normal'``  -> normal initializer with standard deviation ``stddev``.
    * ``'he'``      -> normal initializer with stddev ``sqrt(2 / fan_in)``,
      where ``fan_in`` comes from ``_get_dims(shape)``.
    * ``'xavier'``  -> uniform initializer over ``[-limit, limit]`` with
      ``limit = sqrt(6 / (fan_in + fan_out))``, both taken from
      ``_get_dims(shape)``.
    * any other value -> truncated normal initializer with ``stddev``.

    Args:
        name: Variable name passed to ``tf.get_variable``.
        shape: Variable shape; also passed to ``_get_dims`` for 'he'/'xavier'.
        init: Initialization scheme name (see above).
        range: Half-width for the 'uniform' scheme. Shadows the builtin,
            but it is kept because callers may pass it by keyword.
        stddev: Standard deviation for the 'normal' and fallback schemes.
        init_val: Optional constant initial value; overrides ``init``.

    Returns:
        The variable returned by ``tf.get_variable``.

    Side effects:
        Appends ``tf.nn.l2_loss(variable)`` to the graph collection ``'l2'``.
    """
    def _pick_initializer():
        # An explicit constant always wins over the named scheme.
        if init_val is not None:
            return tf.constant_initializer(init_val)
        if init == 'uniform':
            return tf.random_uniform_initializer(-range, range)
        if init == 'normal':
            return tf.random_normal_initializer(stddev=stddev)
        if init == 'he':
            fan_in, _ = _get_dims(shape)
            # Note: a zero fan_in would raise ZeroDivisionError here.
            return tf.random_normal_initializer(stddev=math.sqrt(2.0 / fan_in))
        if init == 'xavier':
            fan_in, fan_out = _get_dims(shape)
            limit = math.sqrt(6.0 / (fan_in + fan_out))
            return tf.random_uniform_initializer(-limit, limit)
        # Unrecognized scheme names silently fall back to a truncated normal.
        return tf.truncated_normal_initializer(stddev=stddev)

    variable = tf.get_variable(name, shape, initializer=_pick_initializer())
    tf.add_to_collection('l2', tf.nn.l2_loss(variable))
    return variable
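# 评论列表 ("comment list" — leftover web-page text, not code)
# 文章目录 ("article table of contents" — leftover web-page text, not code)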