def create_model(session, config, forward_only):
    """Build the Seq2SeqModel and either restore or initialize its variables.

    Args:
        session: TensorFlow session used to restore the checkpoint or to run
            the global variable initializer.
        config: Object whose attributes supply the model hyperparameters read
            below (vocabulary sizes, buckets, layer sizes, etc.).
        forward_only: When truthy, no optimizer is created and the model is
            built with ``forward_only`` set.

    Returns:
        The constructed ``seq2seq_model.Seq2SeqModel`` instance.
    """
    # An Adam optimizer is only created when the model will be trained.
    optimizer = None if forward_only else tf.train.AdamOptimizer(config.learning_rate)

    # Resolve the activation by name; anything other than "elu"/"prelu"
    # falls back to the identity function.
    activation_name = config.activation
    if activation_name == "elu":
        activation = tf.nn.elu
    elif activation_name == "prelu":
        activation = prelu
    else:
        activation = tf.identity

    # Initializers are passed on as-is (not called here).
    if config.orthogonal_initializer:
        weight_initializer = tf.orthogonal_initializer
    else:
        weight_initializer = tf.uniform_unit_scaling_initializer
    bias_initializer = tf.zeros_initializer

    # Positional arguments, in the order Seq2SeqModel receives them.
    positional = (
        config.en_vocab_size,
        config.fr_vocab_size,
        config.buckets,
        config.size,
        config.num_layers,
        config.latent_dim,
        config.max_gradient_norm,
        config.batch_size,
        config.learning_rate,
        config.kl_min,
        config.word_dropout_keep_prob,
        config.anneal,
        config.use_lstm,
    )
    keyword = dict(
        optimizer=optimizer,
        activation=activation,
        forward_only=forward_only,
        feed_previous=config.feed_previous,
        bidirectional=config.bidirectional,
        weight_initializer=weight_initializer,
        bias_initializer=bias_initializer,
        iaf=config.iaf,
        dtype=tf.float32,
    )
    model = seq2seq_model.Seq2SeqModel(*positional, **keyword)

    checkpoint = tf.train.get_checkpoint_state(FLAGS.model_dir)

    # Start from scratch when explicitly requested via FLAGS.new, or when no
    # usable checkpoint is found under FLAGS.model_dir.
    if (
        FLAGS.new
        or not checkpoint
        or not tf.train.checkpoint_exists(checkpoint.model_checkpoint_path)
    ):
        print("Created model with fresh parameters.")
        session.run(tf.global_variables_initializer())
        return model

    print("Reading model parameters from %s" % checkpoint.model_checkpoint_path)
    model.saver.restore(session, checkpoint.model_checkpoint_path)
    return model
# vrae.py - file source code
# Language: python
# Views: 32
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents