shared_util.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:wip-constrained-extractor 作者: brain-research 项目源码 文件源码
def deep_birnn(hps, inputs, sequence_length, num_layers=1):
  """Efficient deep bidirectional rnn.

  Stacks `num_layers` bidirectional block-LSTM layers. Each layer runs one
  LSTM over the steps in order and one over the reversed steps, multiplies
  both outputs by the sequence-length mask, puts the backward outputs back in
  the original step order and concatenates the two directions on the last
  axis. That concatenation is the input of the next layer.

  Variables live in the scopes "birnn_fwd_<layer>" and "birnn_bwd_<layer>"
  under the names "w" and "b".

  Args:
    hps: bag of hyperparameters. Reads batch_size, num_art_steps,
      word_embedding_size and hidden_size.
    inputs: [batch, steps, units] tensor of input embeddings for RNN.
    sequence_length: number of steps for each inputs.
    num_layers: depth of RNN.

  Returns:
    Outputs of RNN: the concatenated forward and backward outputs of the last
    layer, or `inputs` unchanged when `num_layers` is not positive.
  """
  if num_layers <= 0:
    # Nothing to stack; return before building a mask nobody would use.
    return inputs

  # NOTE(review): the mask is built unconditionally, so the
  # `sequence_length is None` path in _reverse_steps is only reachable if
  # create_mask accepts None -- confirm against create_mask.
  sequence_length_mask = tf.expand_dims(
      create_mask(sequence_length, hps.num_art_steps), 2)

  for j in range(num_layers):
    # The first layer consumes the embeddings. Deeper layers consume the
    # concatenated fwd/bwd outputs of the previous layer, i.e. twice the
    # per-direction output width (assumes block_lstm emits hidden_size units
    # per step, as the [.., 4 * hidden_size] weight shape implies).
    if j == 0:
      input_size = hps.word_embedding_size
    else:
      input_size = 2 * hps.hidden_size

    with tf.variable_scope("birnn_fwd_%d" % j):
      fwd_outs = _masked_block_lstm(hps, inputs, input_size,
                                    sequence_length_mask)

    with tf.variable_scope("birnn_bwd_%d" % j):
      rev_inputs = _reverse_steps(inputs, sequence_length)
      bwd_outs = _masked_block_lstm(hps, rev_inputs, input_size,
                                    sequence_length_mask)
      # Undo the reversal so step t of both directions lines up.
      rev_bwd_outs = _reverse_steps(bwd_outs, sequence_length)

    inputs = tf.concat(2, [fwd_outs, rev_bwd_outs])

  # Return only after every layer has run (previously this returned from
  # inside the loop, so layers beyond the first were never built).
  return inputs


def _masked_block_lstm(hps, inputs, input_size, mask):
  """Runs one block_lstm pass over `inputs` and masks the stacked outputs.

  Creates the variables "w" and "b" in the caller's current variable scope,
  so it must be called inside the scope that should own them.

  Args:
    hps: bag of hyperparameters. Reads batch_size, num_art_steps and
      hidden_size.
    inputs: tensor that is split into num_art_steps pieces along axis 1.
    input_size: size of the last axis of `inputs`; sets the weight shape.
    mask: tensor multiplied into the stacked per-step outputs.

  Returns:
    The per-step `h` outputs of block_lstm, stacked along axis 1 and
    multiplied by `mask`.
  """
  w = tf.get_variable(
      "w", [input_size + hps.hidden_size, 4 * hps.hidden_size])
  b = tf.get_variable("b", [4 * hps.hidden_size])
  # block_lstm takes a list with one [batch, units] tensor per step.
  step_inputs = [tf.reshape(t, [hps.batch_size, -1])
                 for t in tf.split(1, hps.num_art_steps, inputs)]
  (_, _, _, _, _, _, h) = block_lstm(
      tf.to_int64(hps.num_art_steps), step_inputs, w, b, forget_bias=1.0)
  outs = tf.concat(1, [tf.expand_dims(step_h, 1) for step_h in h])
  return outs * mask


def _reverse_steps(tensor, sequence_length):
  """Reverses `tensor` along axis 1, per-example when lengths are given."""
  if sequence_length is not None:
    return tf.reverse_sequence(tensor, tf.to_int64(sequence_length), 1)
  # NOTE(review): kept as in the original; verify that the installed
  # tf.reverse accepts a bare `1` here for this TensorFlow version.
  return tf.reverse(tensor, 1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号