seq2seq_helpers.py source code

Language: Python

Project: DeepDeepParser · Author: janmbuys

The helper below maps each encoder input symbol to a target-vocabulary id and accumulates the corresponding attention logits per target symbol, producing a [batch_size, target_vocab_size] logit matrix for forced attention.
import tensorflow as tf  # written against legacy TensorFlow 0.x APIs (tf.pack, tf.to_int64)


def gather_forced_att_logits(encoder_input_symbols, encoder_decoder_vocab_map,
                             att_logit, batch_size, attn_length,
                             target_vocab_size):
  """Gathers attention weights as logits for forced attention."""
  # Map each encoder input symbol to its target-vocabulary id.
  flat_input_symbols = tf.reshape(encoder_input_symbols, [-1])
  flat_label_symbols = tf.gather(encoder_decoder_vocab_map,
                                 flat_input_symbols)
  flat_att_logits = tf.reshape(att_logit, [-1])

  # Recover (batch, position) coordinates from the flattened index and pair
  # them with the mapped target-vocabulary id of each encoder position.
  flat_range = tf.to_int64(tf.range(tf.shape(flat_label_symbols)[0]))
  batch_inds = tf.floordiv(flat_range, attn_length)
  position_inds = tf.mod(flat_range, attn_length)
  attn_vocab_inds = tf.transpose(tf.pack(
      [batch_inds, position_inds, tf.to_int64(flat_label_symbols)]))

  # Exclude indexes of entries with flat_label_symbols[i] == -1, i.e.
  # encoder symbols with no counterpart in the target vocabulary.
  included_flat_indexes = tf.reshape(tf.where(tf.not_equal(
      flat_label_symbols, -1)), [-1])
  included_attn_vocab_inds = tf.gather(attn_vocab_inds,
                                       included_flat_indexes)
  included_flat_att_logits = tf.gather(flat_att_logits,
                                       included_flat_indexes)

  # Scatter the logits into a sparse [batch_size, attn_length,
  # target_vocab_size] tensor and sum over the attention-position axis, so
  # that repeated target symbols accumulate their logits.
  sparse_shape = tf.to_int64(tf.pack(
      [batch_size, attn_length, target_vocab_size]))

  sparse_label_logits = tf.SparseTensor(included_attn_vocab_inds,
                                        included_flat_att_logits, sparse_shape)
  forced_att_logit_sum = tf.sparse_reduce_sum(sparse_label_logits, [1])

  forced_att_logit = tf.reshape(forced_att_logit_sum,
                                [-1, target_vocab_size])

  return forced_att_logit
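
As a minimal usage sketch (all values and shapes below are hypothetical toy data, and it assumes the same legacy TensorFlow 0.x environment as the function above): a batch of 2 encoder sequences of length 3, a vocab map that sends encoder ids to target-vocab ids (or -1 for unmapped symbols), and one attention logit per encoder position.

import numpy as np
import tensorflow as tf

# Hypothetical toy sizes: batch of 2, attention length 3, target vocab of 5.
batch_size, attn_length, target_vocab_size = 2, 3, 5

# Encoder input symbols, shape [batch_size, attn_length].
encoder_input_symbols = tf.constant([[0, 1, 2],
                                     [2, 2, 3]])
# Maps encoder ids 0..3 to target-vocab ids; -1 marks unmapped symbols.
encoder_decoder_vocab_map = tf.constant([4, -1, 0, 2])
# One attention logit per encoder position, shape [batch_size, attn_length].
att_logit = tf.constant([[0.1, 0.5, 0.2],
                         [0.3, 0.4, 0.6]])

logits = gather_forced_att_logits(
    encoder_input_symbols, encoder_decoder_vocab_map,
    att_logit, batch_size, attn_length, target_vocab_size)

with tf.Session() as sess:
  # Row 0: 0.1 lands on target id 4 and 0.2 on target id 0; the logit for
  # encoder id 1 (mapped to -1) is dropped. Row 1: 0.3 + 0.4 accumulate on
  # target id 0, and 0.6 lands on target id 2.
  print(sess.run(logits))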