def __init__(self, query_size, keys, values, values_length,
             name='attention'):
    """Build an additive (Bahdanau-style) attention helper over `keys`/`values`.

    Creates a linear projection from the query space into the attention
    space, a learned score vector `v_att`, and a boolean-ish mask used to
    exclude padded timesteps from the attention scores.

    Args:
        query_size: int, dimensionality of the incoming query vectors.
        keys: Tensor of attention keys; its last dimension defines the
            attention size. Presumably shaped (time, batch, dim) when
            TIME_MAJOR else (batch, time, dim) — TODO confirm with callers.
        values: Tensor of attention values, stored for later use.
        values_length: int Tensor of valid (unpadded) lengths per sequence.
        name: variable scope / parameter name prefix.
    """
    # Last static dimension of `keys` fixes the attention size.
    self.attention_size = keys.get_shape().as_list()[-1]
    self.keys = keys
    self.values = values
    self.values_length = values_length
    # Projects queries (query_size -> attention_size) so they can be
    # combined with the keys.
    self.query_trans = LinearOp(query_size, self.attention_size, name=name)
    with tf.variable_scope(name):
        # Learned score vector for the additive attention energy.
        self.v_att = tf.get_variable('v_att', shape=[self.attention_size],
                                     dtype=DTYPE)
    # Axis that indexes time depends on the global layout convention.
    self.time_axis = 0 if TIME_MAJOR else 1
    # Mask out scores for padded positions (callers replace them with
    # tf.float32.min before the softmax).
    num_scores = tf.shape(self.keys)[self.time_axis]
    # tf.cast replaces the deprecated tf.to_int32 (removed in TF2);
    # behavior is identical.
    scores_mask = tf.sequence_mask(
        lengths=tf.cast(self.values_length, tf.int32),
        maxlen=tf.cast(num_scores, tf.int32),
        dtype=DTYPE)
    if TIME_MAJOR:
        # sequence_mask yields (batch, time); time-major layout needs
        # (time, batch).
        scores_mask = tf.transpose(scores_mask)
    self.scores_mask = scores_mask
# NOTE(review): removed stray web-scrape navigation text that was pasted in
# here ("评论列表" = comment list, "文章目录" = article table of contents);
# as bare text it was a Python syntax error.