def get_variable_initializer(hparams):
    """Return the TensorFlow variable initializer named by ``hparams``.

    Args:
      hparams: Object exposing ``initializer`` and ``initializer_gain``.
        ``initializer`` must be one of "orthogonal", "uniform",
        "normal_unit_scaling" or "uniform_unit_scaling".
        ``initializer_gain`` is the scale passed to the chosen initializer.

    Returns:
      The initializer built by the matching ``tf`` factory function.

    Raises:
      ValueError: If ``hparams.initializer`` is not one of the names above.
    """
    kind = hparams.initializer

    if kind == "orthogonal":
        return tf.orthogonal_initializer(gain=hparams.initializer_gain)

    if kind == "uniform":
        # Symmetric bounds at one tenth of the gain: U(-0.1*gain, 0.1*gain).
        limit = 0.1 * hparams.initializer_gain
        return tf.random_uniform_initializer(-limit, limit)

    # The two "unit scaling" variants both use fan_avg variance scaling and
    # differ only in the sampling distribution handed to TensorFlow.
    for scaling_name, dist in (("normal_unit_scaling", "normal"),
                               ("uniform_unit_scaling", "uniform")):
        if kind == scaling_name:
            return tf.variance_scaling_initializer(
                hparams.initializer_gain, mode="fan_avg", distribution=dist)

    raise ValueError("Unrecognized initializer: %s" % kind)
评论列表
文章目录