worker.py source code (Python)

Project: ai-copywriter | Author: ematvey

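import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.train.Saver)
from tensorflow.contrib.rnn import LSTMCell, MultiRNNCell

# Seq2SeqModel, wiki, args and checkpoint_dir are defined elsewhere in the
# ai-copywriter project; they are not part of this excerpt.


def create_model(session, restore_only=False):
  """Build the seq2seq model, then restore the latest checkpoint or initialize it."""
  is_training = tf.placeholder(dtype=tf.bool, name='is_training')

  # With a bidirectional encoder, the forward and backward states are
  # concatenated, so the decoder state size must be 2x the encoder state
  # size (2 * 64 = 128).
  # Build a new LSTMCell for each of the 5 layers: `[cell] * 5` would put the
  # same cell object in every layer, which newer TensorFlow 1.x releases
  # reject (or which ties the layers' weights together).
  encoder_cell = MultiRNNCell([LSTMCell(64) for _ in range(5)])
  decoder_cell = MultiRNNCell([LSTMCell(128) for _ in range(5)])
  model = Seq2SeqModel(encoder_cell=encoder_cell,
                       decoder_cell=decoder_cell,
                       vocab_size=wiki.vocab_size,
                       embedding_size=300,
                       attention=True,
                       bidirectional=True,
                       is_training=is_training,
                       device=args.device,
                       debug=False)

  # Besides the most recent checkpoints (max_to_keep, 5 by default),
  # keep one checkpoint per hour of training.
  saver = tf.train.Saver(tf.global_variables(), keep_checkpoint_every_n_hours=1)
  checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
  if checkpoint and checkpoint.model_checkpoint_path:
    print("Reading model parameters from %s" % checkpoint.model_checkpoint_path)
    saver.restore(session, checkpoint.model_checkpoint_path)
  elif restore_only:
    # Callers that only run inference must not silently get random weights.
    raise FileNotFoundError("Cannot restore model: no checkpoint in %s" % checkpoint_dir)
  else:
    print("Created model with fresh parameters")
    session.run(tf.global_variables_initializer())
  # Make the graph read-only: adding ops later (e.g. inside a training loop)
  # raises an error instead of silently growing the graph.
  tf.get_default_graph().finalize()
  return model, saver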
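
The rule that the decoder state must be twice the encoder state comes from how a bidirectional encoder hands its final state to the decoder: for each layer, the forward and backward LSTM states (cell state `c` and hidden state `h`) are joined end to end, so 64-unit encoder layers yield 128-wide vectors. The following is a minimal plain-Python sketch, assuming states are lists of floats and that `Seq2SeqModel` merges states by concatenation (its code is not shown here). `LSTMState`, `merge_bidirectional_states` and `zero_state` are hypothetical names, not TensorFlow APIs.

```python
from typing import List, NamedTuple


class LSTMState(NamedTuple):
    """Hypothetical per-layer LSTM state: cell state c and hidden state h."""
    c: List[float]
    h: List[float]


def merge_bidirectional_states(fw: List[LSTMState], bw: List[LSTMState]) -> List[LSTMState]:
    """Concatenate forward and backward states layer by layer (c with c, h with h)."""
    if len(fw) != len(bw):
        raise ValueError("forward and backward encoders must have the same depth")
    return [LSTMState(c=f.c + b.c, h=f.h + b.h) for f, b in zip(fw, bw)]


def zero_state(num_units: int, num_layers: int) -> List[LSTMState]:
    """Build a stack of all-zero states, one LSTMState per layer."""
    return [LSTMState(c=[0.0] * num_units, h=[0.0] * num_units) for _ in range(num_layers)]


if __name__ == "__main__":
    merged = merge_bidirectional_states(zero_state(64, 5), zero_state(64, 5))
    # Each merged layer state is 2 * 64 = 128 wide, matching LSTMCell(128) in the decoder.
    print(len(merged), len(merged[0].c), len(merged[0].h))  # 5 128 128
```

Concatenation keeps both directions' information intact, which is why the decoder is sized up instead of the states being summed or averaged.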
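
Stacking layers with `[cell] * 5` repeats a single reference five times, whereas a list comprehension constructs a new object per layer. Since each RNN cell instance owns its weights, only the second form gives every layer its own parameters. The plain-Python sketch below makes the difference visible; `DummyCell` is a hypothetical stand-in, not a TensorFlow class.

```python
class DummyCell:
    """Hypothetical stand-in for an RNN cell; each instance would own its own weights."""

    def __init__(self, num_units):
        self.num_units = num_units


def stack_by_repetition(num_units, num_layers):
    # One cell object, referenced num_layers times.
    cell = DummyCell(num_units)
    return [cell] * num_layers


def stack_by_comprehension(num_units, num_layers):
    # A fresh cell object for every layer.
    return [DummyCell(num_units) for _ in range(num_layers)]


def distinct_cells(layers):
    """Count how many distinct cell objects a layer list contains."""
    return len({id(layer) for layer in layers})


if __name__ == "__main__":
    print(distinct_cells(stack_by_repetition(64, 5)))     # 1
    print(distinct_cells(stack_by_comprehension(64, 5)))  # 5
```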
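
`keep_checkpoint_every_n_hours=1` changes which checkpoints survive. According to its documentation, `tf.train.Saver` keeps the `max_to_keep` most recent checkpoints (5 by default) and additionally preserves one checkpoint for every N hours of training. The sketch below is a simplified model of that policy, not TensorFlow's implementation: when a checkpoint falls out of the recent window, it is kept permanently only if it was written after the next hourly mark. `simulate_retention` and its seconds-based timestamps are hypothetical.

```python
from typing import List, Tuple


def simulate_retention(save_times: List[float], max_to_keep: int = 5,
                       every_n_hours: float = 1.0,
                       start_time: float = 0.0) -> Tuple[List[float], List[float]]:
    """Simplified model of checkpoint retention.

    Returns (preserved, recent): checkpoints kept permanently by the hourly
    rule, and the max_to_keep most recent checkpoints.
    """
    interval = every_n_hours * 3600.0
    next_keep_time = start_time + interval
    recent: List[float] = []
    preserved: List[float] = []
    for t in save_times:
        recent.append(t)
        if len(recent) > max_to_keep:
            oldest = recent.pop(0)
            # The evicted checkpoint survives only if it is past the next hourly mark.
            if oldest > next_keep_time:
                preserved.append(oldest)
                next_keep_time += interval
    return preserved, recent


if __name__ == "__main__":
    # One checkpoint every 10 minutes for 3 hours (18 saves).
    times = [600.0 * i for i in range(1, 19)]
    preserved, recent = simulate_retention(times)
    print(preserved)  # [4200.0, 7800.0]
    print(recent)     # [8400.0, 9000.0, 9600.0, 10200.0, 10800.0]
```

Under this model a long training run keeps a sparse hourly history for later comparison while disk usage from frequent saves stays bounded.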
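
The checkpoint logic at the end of `create_model` has three outcomes: restore, raise (when `restore_only` is set and nothing has been saved), or initialize from scratch. Isolating that decision in a pure function lets it be unit-tested without building a graph. This is a minimal sketch: `choose_startup_action` is a hypothetical helper, and its `checkpoint_path` argument stands in for `checkpoint.model_checkpoint_path`, or `None` when `tf.train.get_checkpoint_state` finds nothing.

```python
from typing import Optional


def choose_startup_action(checkpoint_path: Optional[str], restore_only: bool = False) -> str:
    """Return "restore" or "initialize"; raise if a restore is required but impossible."""
    if checkpoint_path:
        return "restore"
    if restore_only:
        # Mirrors create_model: inference callers must not silently get random weights.
        raise FileNotFoundError("Cannot restore model")
    return "initialize"


if __name__ == "__main__":
    print(choose_startup_action("ckpt/model.ckpt-1000"))  # restore
    print(choose_startup_action(None))                    # initialize
```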