seq2seq_helpers.py source code

python

Project: DeepDeepParser  Author: janmbuys
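import numpy as np
import tensorflow as tf

import data_utils  # project module that defines the transition-state constants


def mask_decoder_only_shift(logit, thin_stack_head_next, transition_state_map,
                          logit_size, batch_size):
  """Ensures that an empty stack can only be followed by GEN_STATE (shift).

  For each batch entry k:
    If thin_stack_head_next[k] <= 1 (the stack is treated as empty),
      set logit[k][i] = -np.inf for every index i that maps to a reduce or
      arc state,
    else leave logit[k] unchanged.
  """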
  # A stack pointer of 0 or 1 is treated as an empty stack.
  stack_is_empty_bool = tf.less_equal(thin_stack_head_next, 1)
  stack_is_empty = tf.select(stack_is_empty_bool,
                            tf.ones(tf.pack([batch_size]), dtype=tf.int32),
                            tf.zeros(tf.pack([batch_size]), dtype=tf.int32))
  stack_is_empty = tf.reshape(stack_is_empty, [-1, 1])

  # Reduce and arc states are disallowed (but not root).
  state_is_disallowed_updates = tf.sparse_to_dense(
      tf.pack([data_utils.RE_STATE, data_utils.ARC_STATE]),
      tf.pack([data_utils.NUM_TR_STATES]), 1)
  logit_states = tf.gather(transition_state_map, tf.range(logit_size))
  state_is_disallowed = tf.gather(state_is_disallowed_updates, logit_states)
  state_is_disallowed = tf.reshape(state_is_disallowed, [1, -1])

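  # Outer product [batch_size, 1] x [1, logit_size]: 1 if disallowed.
  index_delta = tf.matmul(stack_is_empty, state_is_disallowed)
  # Index 0 leaves the logit unchanged; index 1 adds -inf to mask it out.
  values = tf.pack([0.0, -np.inf])
  delta = tf.gather(values, index_delta)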
  new_logit = logit + delta
  return new_logit
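
The masking logic can be checked outside a TensorFlow graph with a small NumPy sketch. This is only an illustration of the same idea, not the project's code: the function name `mask_decoder_only_shift_np` and the numeric values of `GEN_STATE`, `RE_STATE`, `ARC_STATE`, `ROOT_STATE` and `NUM_TR_STATES` are assumptions made for the example (the real constants live in `data_utils`), and boolean broadcasting stands in for the outer-product `tf.matmul`.

```python
import numpy as np

# Hypothetical state ids, chosen only for this illustration; the real values
# are defined in the project's data_utils module.
GEN_STATE, RE_STATE, ARC_STATE, ROOT_STATE = 0, 1, 2, 3
NUM_TR_STATES = 4


def mask_decoder_only_shift_np(logit, thin_stack_head_next,
                               transition_state_map):
    """NumPy sketch: add -inf to reduce/arc logits where the stack is empty."""
    logit = np.asarray(logit, dtype=np.float64)
    logit_size = logit.shape[1]

    # Shape [batch_size, 1]: True where the stack pointer is 0 or 1.
    stack_is_empty = (np.asarray(thin_stack_head_next) <= 1)[:, None]

    # Lookup table over transition states: True for disallowed states.
    state_is_disallowed = np.zeros(NUM_TR_STATES, dtype=bool)
    state_is_disallowed[[RE_STATE, ARC_STATE]] = True

    # Shape [1, logit_size]: True where the logit maps to a disallowed state.
    logit_states = np.asarray(transition_state_map)[:logit_size]
    logit_is_disallowed = state_is_disallowed[logit_states][None, :]

    # Broadcasting the two masks gives a [batch_size, logit_size] mask.
    mask = stack_is_empty & logit_is_disallowed
    return logit + np.where(mask, -np.inf, 0.0)


logits = np.zeros((2, 4))
heads = np.array([1, 3])  # first entry: empty stack; second entry: not empty
state_map = np.array([GEN_STATE, RE_STATE, ARC_STATE, ROOT_STATE])
masked = mask_decoder_only_shift_np(logits, heads, state_map)
print(masked)
```

In this example only the first batch entry has an empty stack, so only its reduce and arc logits become `-inf`; the shift and root logits, and the whole second row, are left unchanged.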