def get_weight_variable(shape, name=None, type='xavier_uniform', regularize=True, **kwargs):
    """Create a float32 weight variable using the requested initialisation scheme.

    Args:
        shape: Shape of the variable to create.
        name: Variable name. When given, the variable is created through
            ``tf.get_variable``; when ``None``, an unnamed ``tf.Variable``
            is created instead.
        type: Initialisation scheme. One of:
            ``'xavier_uniform'`` / ``'xavier_normal'`` -- ``xavier_initializer``
            with ``uniform`` set to True / False;
            ``'he_normal'`` / ``'he_uniform'`` -- ``variance_scaling_initializer``
            with ``factor=2.0`` and ``mode='FAN_IN'``;
            ``'caffe_uniform'`` -- ``variance_scaling_initializer`` with
            ``factor=1.0``, ``mode='FAN_IN'`` and ``uniform=True``;
            ``'simple'`` -- truncated normal with standard deviation taken
            from the ``stddev`` keyword argument;
            ``'bilinear'`` -- constant weights produced by
            ``_bilinear_upsample_weights(shape)``.
        regularize: If True, the variable is added to the
            ``'weight_variables'`` graph collection (presumably consumed by
            weight-decay code elsewhere -- confirm against callers).
        **kwargs: ``stddev`` (default 0.02) is read for ``type='simple'``;
            other keys are ignored.

    Returns:
        The created variable.

    Raises:
        ValueError: If ``type`` is not one of the schemes listed above.
    """
    # True when `initial` is a concrete tensor that already carries its shape,
    # False when it is an initializer that must still be told the shape.
    initialise_from_constant = False

    if type == 'xavier_uniform':
        initial = xavier_initializer(uniform=True, dtype=tf.float32)
    elif type == 'xavier_normal':
        initial = xavier_initializer(uniform=False, dtype=tf.float32)
    elif type == 'he_normal':
        initial = variance_scaling_initializer(
            uniform=False, factor=2.0, mode='FAN_IN', dtype=tf.float32)
    elif type == 'he_uniform':
        initial = variance_scaling_initializer(
            uniform=True, factor=2.0, mode='FAN_IN', dtype=tf.float32)
    elif type == 'caffe_uniform':
        initial = variance_scaling_initializer(
            uniform=True, factor=1.0, mode='FAN_IN', dtype=tf.float32)
    elif type == 'simple':
        stddev = kwargs.get('stddev', 0.02)
        initial = tf.truncated_normal(shape, stddev=stddev, dtype=tf.float32)
        initialise_from_constant = True
    elif type == 'bilinear':
        weights = _bilinear_upsample_weights(shape)
        initial = tf.constant(weights, shape=shape, dtype=tf.float32)
        initialise_from_constant = True
    else:
        raise ValueError('Unknown initialisation requested: %s' % type)

    if name is None:  # This keeps the option open to use unnamed Variables
        # tf.Variable needs a concrete initial value (or a zero-argument
        # callable), whereas the initializer objects above require a shape.
        # Evaluate the initializer here so the unnamed path works for every
        # scheme, not only the constant-based ones.
        initial_value = initial if initialise_from_constant else initial(shape)
        weight = tf.Variable(initial_value)
    elif initialise_from_constant:
        # A tensor initializer already fixes the shape, so `shape` must not
        # be passed to tf.get_variable as well.
        weight = tf.get_variable(name, initializer=initial)
    else:
        weight = tf.get_variable(name, shape=shape, initializer=initial)

    if regularize:
        tf.add_to_collection('weight_variables', weight)

    return weight
# Comment list
# Table of contents