seq2seq_helpers.py source code

python

Project: DeepDeepParser    Author: janmbuys
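import tensorflow as tf


def gather_prev_stack_aux_state_index(pointer_vals, prev_index, transition_state,
                                      batch_size):
  """Gather new prev state index for aux rnn: as for main, but zero if shift."""
  # Pointer values as a column vector: [batch_size, 1].
  new_pointer_vals = tf.reshape(pointer_vals, [-1, 1])

  # Helper tensors (tf.pack was renamed tf.stack in TensorFlow 1.0).
  # prev_index repeated for every batch element: [batch_size, 1].
  prev_vals = tf.reshape(tf.fill(
      tf.stack([batch_size]), prev_index), [-1, 1])
  # (row, transition_state) pairs for gather_nd: [batch_size, 2].
  # transition_state must be int32 to be stacked with tf.range.
  trans_inds = tf.transpose(tf.stack(
      [tf.range(batch_size), transition_state]))
  # Column of zeros: [batch_size, 1].
  batch_zeros = tf.reshape(tf.zeros(
      tf.stack([batch_size]), dtype=tf.int32), [-1, 1])

  # Gather new prev state for aux rnn.
  # State inds dimension [batch_size, NUM_TR_STATES]
  # (TensorFlow >= 1.0 takes the values first and the axis second.)
  state_inds = tf.concat(
      [prev_vals, batch_zeros] + [prev_vals] * 4 + [new_pointer_vals, prev_vals],
      axis=1)
  # For each row, pick the column selected by that row's transition state.
  prev_state_index = tf.gather_nd(state_inds, trans_inds)
  return prev_state_index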
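To make the per-row selection concrete, here is a minimal NumPy sketch of the same logic. The name `gather_prev_stack_aux_state_index_np` is hypothetical and not part of DeepDeepParser. It assumes `pointer_vals` and `transition_state` are 1-D integer sequences of equal length and that every transition state lies in `0..7`, matching the eight columns the TensorFlow function concatenates. Reading the docstring, column 1 (the zero column) would correspond to the shift transition, but that mapping is an inference, not something the snippet states.

```python
import numpy as np


# Hypothetical NumPy reference (not part of DeepDeepParser) that mirrors the
# TensorFlow function so the selection can be checked by hand.
def gather_prev_stack_aux_state_index_np(pointer_vals, prev_index,
                                         transition_state):
  pointer_vals = np.asarray(pointer_vals, dtype=np.int32)
  transition_state = np.asarray(transition_state, dtype=np.int32)
  batch_size = pointer_vals.shape[0]

  # Column vectors of shape [batch_size, 1].
  prev_vals = np.full((batch_size, 1), prev_index, dtype=np.int32)
  batch_zeros = np.zeros((batch_size, 1), dtype=np.int32)
  new_pointer_vals = pointer_vals.reshape(-1, 1)

  # Same 8-column table as the TF code: [batch_size, 8].
  # Columns: 0 prev, 1 zero, 2-5 prev, 6 pointer, 7 prev.
  state_inds = np.concatenate(
      [prev_vals, batch_zeros] + [prev_vals] * 4 + [new_pointer_vals, prev_vals],
      axis=1)

  # Row-wise selection, equivalent to tf.gather_nd with (row, state) pairs.
  return state_inds[np.arange(batch_size), transition_state]


result = gather_prev_stack_aux_state_index_np([7, 8, 9, 10], 3, [0, 1, 6, 7])
print(result.tolist())  # → [3, 0, 9, 3]
```

With `prev_index=3`, the transitions `[0, 1, 6, 7]` select the previous index, zero, that row's pointer value, and the previous index again. Building the full table and then indexing it is what lets the TensorFlow version avoid per-row branching.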