worker.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:hierarchical-attention-networks 作者: ematvey 项目源码 文件源码
def HAN_model_1(session, restore_only=False):
  """Build the Hierarchical Attention Network and restore or initialize it.

  Constructs a HANClassifierModel whose word-level and sentence-level
  encoders are each a 5-layer MultiRNNCell stack of BNLSTMCell(80) cells,
  then either restores the latest checkpoint found in ``checkpoint_dir`` or
  runs the global variable initializer.

  Relies on module-level names defined elsewhere in this file:
  ``vocab_size``, ``classes``, ``args`` (``device``, ``lr``,
  ``max_grad_norm``) and ``checkpoint_dir``.

  Args:
    session: TensorFlow session used to restore or initialize variables.
    restore_only: if True, raise instead of creating fresh parameters when
      no checkpoint is available.

  Returns:
    A ``(model, saver)`` tuple: the HANClassifierModel and a
    ``tf.train.Saver`` over all global variables.

  Raises:
    FileNotFoundError: if ``restore_only`` is True and ``checkpoint_dir``
      holds no usable checkpoint.
  """
  import tensorflow as tf
  try:
    from tensorflow.contrib.rnn import MultiRNNCell
  except ImportError:
    # Older TF releases expose the cells under tf.nn.rnn_cell instead.
    MultiRNNCell = tf.nn.rnn_cell.MultiRNNCell
  from bn_lstm import BNLSTMCell
  from HAN_model import HANClassifierModel

  is_training = tf.placeholder(dtype=tf.bool, name='is_training')

  def make_cell():
    # Build a brand-new 5-layer stack on every call. `[cell] * 5` would put
    # the *same* cell instance in every layer; TF >= 1.1 binds an RNNCell to
    # the variable scope of its first use, so reusing one instance across
    # layers (or across the word and sentence encoders) either fails or
    # shares weights unintentionally.
    # h-h batchnorm LSTMCell; GRUCell(30) layers were an alternative here.
    return MultiRNNCell([BNLSTMCell(80, is_training) for _ in range(5)])

  model = HANClassifierModel(
      vocab_size=vocab_size,
      embedding_size=200,
      classes=classes,
      word_cell=make_cell(),
      sentence_cell=make_cell(),
      word_output_size=100,
      sentence_output_size=100,
      device=args.device,
      learning_rate=args.lr,
      max_grad_norm=args.max_grad_norm,
      dropout_keep_proba=0.5,
      is_training=is_training,
  )

  saver = tf.train.Saver(tf.global_variables())
  checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
  # A checkpoint state may exist without a usable path; only restore when
  # there is actually a model file to read.
  if checkpoint and checkpoint.model_checkpoint_path:
    print("Reading model parameters from %s" % checkpoint.model_checkpoint_path)
    saver.restore(session, checkpoint.model_checkpoint_path)
  elif restore_only:
    raise FileNotFoundError("Cannot restore model")
  else:
    print("Created model with fresh parameters")
    session.run(tf.global_variables_initializer())
  # tf.get_default_graph().finalize()
  return model, saver
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号