def get_initializer(params):
    """Build the TensorFlow initializer selected by ``params.initializer``.

    Args:
        params: Object exposing ``initializer`` (the scheme name) and
            ``initializer_gain`` (the numeric argument handed to the
            chosen scheme). ``initializer_gain`` is only read when the
            scheme name is recognized.

    Returns:
        The object produced by the matching ``tf`` factory:

        * ``"uniform"``: ``tf.random_uniform_initializer(-gain, gain)``
        * ``"normal"``: ``tf.random_normal_initializer(0.0, gain)``
        * ``"normal_unit_scaling"``: ``tf.variance_scaling_initializer``
          with ``mode="fan_avg"`` and ``distribution="normal"``
        * ``"uniform_unit_scaling"``: ``tf.variance_scaling_initializer``
          with ``mode="fan_avg"`` and ``distribution="uniform"``

    Raises:
        ValueError: If ``params.initializer`` is none of the names above.
    """
    name = params.initializer

    if name == "uniform":
        # The gain is the half-width of a range symmetric around zero.
        max_val = params.initializer_gain
        return tf.random_uniform_initializer(-max_val, max_val)

    if name == "normal":
        return tf.random_normal_initializer(0.0, params.initializer_gain)

    if name == "normal_unit_scaling":
        return tf.variance_scaling_initializer(params.initializer_gain,
                                               mode="fan_avg",
                                               distribution="normal")

    if name == "uniform_unit_scaling":
        return tf.variance_scaling_initializer(params.initializer_gain,
                                               mode="fan_avg",
                                               distribution="uniform")

    raise ValueError("Unrecognized initializer: %s" % params.initializer)
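# Scraped page navigation text, not code: "评论列表" (comment list)
# Scraped page navigation text, not code: "文章目录" (article table of contents)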