vrae.py source code

python

Project: Variational-Recurrent-Autoencoder-Tensorflow  Author: Chung-I
import tensorflow as tf  # TensorFlow 1.x API (tf.train.*, tf.global_variables_initializer)

import seq2seq_model  # project module that provides Seq2SeqModel

def create_model(session, config, forward_only):
  """Create the seq2seq model and initialize or load its parameters in session."""
  dtype = tf.float32
  # Build an optimizer only for training; forward-only (inference) runs need none.
  optimizer = None
  if not forward_only:
    optimizer = tf.train.AdamOptimizer(config.learning_rate)
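  # Pick the activation named in the config; any other value falls back to identity.
  # `prelu` is not a TensorFlow built-in here, so it must be defined or imported
  # elsewhere in vrae.py.
  if config.activation == "elu":
    activation = tf.nn.elu
  elif config.activation == "prelu":
    activation = prelu
  else:
    activation = tf.identity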

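  # Choose the weight initializer from the config; biases always start at zero.
  if config.orthogonal_initializer:
    weight_initializer = tf.orthogonal_initializer
  else:
    weight_initializer = tf.uniform_unit_scaling_initializer
  bias_initializer = tf.zeros_initializer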

  model = seq2seq_model.Seq2SeqModel(
      config.en_vocab_size,
      config.fr_vocab_size,
      config.buckets,
      config.size,
      config.num_layers,
      config.latent_dim,
      config.max_gradient_norm,
      config.batch_size,
      config.learning_rate,
      config.kl_min,
      config.word_dropout_keep_prob,
      config.anneal,
      config.use_lstm,
      optimizer=optimizer,
      activation=activation,
      forward_only=forward_only,
      feed_previous=config.feed_previous,
      bidirectional=config.bidirectional,
      weight_initializer=weight_initializer,
      bias_initializer=bias_initializer,
      iaf=config.iaf,
      dtype=dtype)
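  # Look for the latest checkpoint in the model directory. FLAGS is the
  # module-level flags object and is not passed in as an argument.
  ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
  if not FLAGS.new and ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
    # Resume from the saved parameters unless a fresh model was requested.
    print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
    model.saver.restore(session, ckpt.model_checkpoint_path)
  else:
    # No usable checkpoint (or FLAGS.new is set): initialize all variables.
    print("Created model with fresh parameters.")
    session.run(tf.global_variables_initializer())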
  return model
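
The activation selection at the top of create_model can be illustrated without TensorFlow. The following is a minimal sketch, not the project's code: the scalar `elu`, `prelu` and `identity` functions, the `ACTIVATIONS` table and `select_activation` are hypothetical names introduced here, and the PReLU slope is a fixed constant, whereas a real PReLU layer would normally learn it.

```python
import math


def elu(x, alpha=1.0):
    """ELU: identity for positive inputs, alpha * (exp(x) - 1) otherwise."""
    return x if x > 0 else alpha * (math.exp(x) - 1.0)


def prelu(x, alpha=0.25):
    """PReLU with a fixed slope (assumption: a trained model would learn alpha)."""
    return x if x > 0 else alpha * x


def identity(x):
    """Pass the input through unchanged, like tf.identity."""
    return x


# Hypothetical lookup table standing in for the if/elif/else chain in create_model.
ACTIVATIONS = {"elu": elu, "prelu": prelu}


def select_activation(name):
    """Return the activation for `name`; unknown names fall back to identity."""
    return ACTIVATIONS.get(name, identity)


if __name__ == "__main__":
    for name in ("elu", "prelu", "linear"):
        act = select_activation(name)
        print(name, [act(v) for v in (-2.0, 0.0, 2.0)])
```

A dictionary lookup with a default keeps the fallback behaviour of the original else branch while making it easy to register further activations.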