def gather_forced_att_logits(encoder_input_symbols, encoder_decoder_vocab_map,
                             att_logit, batch_size, attn_length,
                             target_vocab_size):
  """Accumulates attention logits onto their mapped target-vocabulary ids.

  Each encoder input symbol is translated to a target-vocab id through
  `encoder_decoder_vocab_map`. The attention logit for that encoder position
  is scattered to (batch, position, vocab_id) in a sparse tensor; entries
  whose mapped id is -1 (no target-vocab counterpart) are excluded, and the
  position axis is then summed out, so symbols appearing at several encoder
  positions have their logits added together.

  Args:
    encoder_input_symbols: int tensor of encoder symbol ids, reshapable to
      [batch_size * attn_length].
    encoder_decoder_vocab_map: 1-D int tensor mapping an encoder symbol id to
      a target-vocab id, with -1 marking symbols that have no mapping.
    att_logit: attention logits aligned element-wise with
      `encoder_input_symbols` after flattening.
    batch_size: number of sequences in the batch.
    attn_length: number of encoder positions attended over.
    target_vocab_size: size of the target (decoder) vocabulary.

  Returns:
    A dense [batch_size, target_vocab_size] tensor whose (b, v) entry is the
    sum of attention logits of encoder positions in batch b whose symbol maps
    to vocab id v.
  """
  flat_input_symbols = tf.reshape(encoder_input_symbols, [-1])
  flat_label_symbols = tf.gather(encoder_decoder_vocab_map,
                                 flat_input_symbols)
  flat_att_logits = tf.reshape(att_logit, [-1])
  # Recover (batch, position) coordinates from the flattened index: the
  # flattening is row-major over [batch_size, attn_length].
  flat_range = tf.to_int64(tf.range(tf.shape(flat_label_symbols)[0]))
  batch_inds = tf.floordiv(flat_range, attn_length)
  position_inds = tf.mod(flat_range, attn_length)
  attn_vocab_inds = tf.transpose(tf.pack(
      [batch_inds, position_inds, tf.to_int64(flat_label_symbols)]))
  # Exclude indexes of entries with flat_label_symbols[i] = -1, i.e. encoder
  # symbols with no target-vocab mapping (-1 is not a valid sparse index).
  included_flat_indexes = tf.reshape(tf.where(tf.not_equal(
      flat_label_symbols, -1)), [-1])
  included_attn_vocab_inds = tf.gather(attn_vocab_inds,
                                       included_flat_indexes)
  included_flat_att_logits = tf.gather(flat_att_logits,
                                       included_flat_indexes)
  sparse_shape = tf.to_int64(tf.pack(
      [batch_size, attn_length, target_vocab_size]))
  sparse_label_logits = tf.SparseTensor(included_attn_vocab_inds,
                                        included_flat_att_logits, sparse_shape)
  # Sum over the position axis so repeated symbols accumulate their logits.
  forced_att_logit_sum = tf.sparse_reduce_sum(sparse_label_logits, [1])
  forced_att_logit = tf.reshape(forced_att_logit_sum,
                                [-1, target_vocab_size])
  return forced_att_logit