rnn.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:UnstableParser 作者: tdozat 项目源码 文件源码
def birnn(cell, inputs, sequence_length, initial_state_fw=None, initial_state_bw=None, ff_keep_prob=1., recur_keep_prob=1., dtype=tf.float32, scope=None):
  """Run `rnn` over `inputs` in both time directions and join the outputs.

  The forward pass runs `rnn` on `inputs` as given. The backward pass runs
  `rnn` on a copy of `inputs` whose valid prefix (per `sequence_length`) has
  been reversed along axis 1, and its output is reversed back so that both
  outputs are aligned step-for-step before being concatenated on axis 2.

  Args:
    cell: the recurrent cell handed to `rnn`; the same object is used for
      both directions, with each direction building under its own
      variable scope.
    inputs: input tensor with the batch on axis 0 and time on axis 1 (these
      are the axes passed to `tf.reverse_sequence`).
    sequence_length: per-example sequence lengths, used both by `rnn` and to
      decide how much of each example to reverse.
    initial_state_fw: optional initial state for the forward pass.
    initial_state_bw: optional initial state for the backward pass.
    ff_keep_prob: forwarded unchanged to `rnn` (presumably the keep
      probability for feed-forward dropout -- confirm against `rnn`).
    recur_keep_prob: forwarded unchanged to `rnn` (presumably the keep
      probability for recurrent dropout -- confirm against `rnn`).
    dtype: forwarded unchanged to `rnn`.
    scope: optional variable scope (name or scope object). When omitted, the
      two directions use the top-level scopes 'BiRNN_FW' and 'BiRNN_BW'.
      When given, the same two names are created *inside* `scope`.

  Returns:
    A pair `(outputs, states)`: `outputs` is the forward and backward outputs
    concatenated along axis 2, and `states` is
    `tf.tuple([output_state_fw, output_state_bw])`.
  """

  def _run_direction(direction_inputs, initial_state, direction_name):
    """Call `rnn` once inside the variable scope reserved for one direction."""

    def _call_rnn(direction_scope):
      return rnn(cell, direction_inputs, sequence_length, initial_state, ff_keep_prob, recur_keep_prob, dtype, scope=direction_scope)

    if not scope:
      # No caller scope: keep the original top-level names so existing
      # variable names stay exactly as they were.
      with tf.variable_scope(direction_name) as direction_scope:
        return _call_rnn(direction_scope)
    # A caller scope must not be entered directly by both directions:
    # they would then create their variables under identical names and
    # either collide or silently share weights. Nest one sub-scope per
    # direction inside it instead.
    with tf.variable_scope(scope):
      with tf.variable_scope(direction_name) as direction_scope:
        return _call_rnn(direction_scope)

  # Forward direction
  output_fw, output_state_fw = _run_direction(inputs, initial_state_fw, 'BiRNN_FW')

  # Backward direction: reverse only the valid part of each example
  # (sequence axis 1, batch axis 0), run the RNN, then undo the reversal so
  # the backward outputs line up with the forward ones.
  rev_inputs = tf.reverse_sequence(inputs, sequence_length, 1, 0)
  output_bw, output_state_bw = _run_direction(rev_inputs, initial_state_bw, 'BiRNN_BW')
  output_bw = tf.reverse_sequence(output_bw, sequence_length, 1, 0)

  # Concat each of the forward/backward outputs
  outputs = tf.concat([output_fw, output_bw], 2)

  return outputs, tf.tuple([output_state_fw, output_state_bw])

#===============================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号