def __call__(self, query):
    with tf.variable_scope('attention'):
        # NOTE: assumes the memory's batch_size is consistent with the query's batch_size (not checked here)
        query_units = query.get_shape()[-1].value
        Wa = tf.get_variable(name='Wa', shape=(query_units, self.attention_units))
        Va = tf.get_variable(name='Va', shape=(self.attention_units,),
                             initializer=tf.constant_initializer(0.0) if self.mode == 0
                             else tf.constant_initializer(1e-2))
        b = tf.get_variable(name='b', shape=(self.attention_units,),
                            initializer=tf.constant_initializer(0.0) if self.mode == 0
                            else tf.constant_initializer(0.5))
        # 1st: compute query_feat (the query's representation in the attention module),
        # shaped (batch_size, 1, 1, attention_units) so it broadcasts over encoder time steps
        query_feat = tf.reshape(tf.matmul(query, Wa), (-1, 1, 1, self.attention_units))
        # 2nd: compute the energy for every encoder time step (element-wise mul, then reduce);
        # result shape is (batch_size, enc_length)
        e = tf.reduce_sum(Va * tf.nn.tanh(self.hidden_feats + query_feat + b), axis=(2, 3))
        # 3rd: compute the attention weights (softmax over encoder time steps,
        # zeroing out padded positions when a mask is provided)
        if self.mask is not None:
            exp_e = tf.exp(e) * self.mask
            alpha = tf.divide(exp_e, tf.reduce_sum(exp_e, axis=-1, keep_dims=True))
        else:
            alpha = tf.nn.softmax(e)
        # 4th: get the weighted context from the memory (element-wise mul, then reduce);
        # result shape is (batch_size, memory_units)
        context = tf.reshape(alpha, (tf.shape(query)[0], self.enc_length, 1, 1)) * self.memory
        context = tf.reduce_sum(context, axis=(1, 2))
        return context, alpha
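
For reference, below is a minimal, self-contained NumPy sketch of the same additive (Bahdanau-style) attention forward pass, written only to make the tensor shapes at each step explicit. The concrete sizes, the 4-D (batch_size, enc_length, 1, units) layout of memory and hidden_feats, and the idea that hidden_feats is a precomputed projection of the memory are assumptions inferred from the code above, not part of the original class.

# Assumed shapes, inferred from the TF code above
import numpy as np

batch_size, enc_length, attention_units, memory_units, query_units = 2, 5, 8, 16, 16

memory = np.random.randn(batch_size, enc_length, 1, memory_units)           # encoder outputs
hidden_feats = np.random.randn(batch_size, enc_length, 1, attention_units)  # assumed precomputed projection of memory
query = np.random.randn(batch_size, query_units)                            # decoder state
Wa = np.random.randn(query_units, attention_units)
Va = np.random.randn(attention_units)
b = np.zeros(attention_units)

# 1st: project the query and broadcast it over encoder time steps
query_feat = (query @ Wa).reshape(batch_size, 1, 1, attention_units)
# 2nd: additive energy e = Va^T tanh(hidden_feats + query_feat + b), one value per time step
e = np.sum(Va * np.tanh(hidden_feats + query_feat + b), axis=(2, 3))        # (batch_size, enc_length)
# 3rd: softmax over encoder time steps (no mask in this sketch)
alpha = np.exp(e) / np.exp(e).sum(axis=-1, keepdims=True)
# 4th: weighted sum of the memory
context = (alpha.reshape(batch_size, enc_length, 1, 1) * memory).sum(axis=(1, 2))
print(context.shape, alpha.shape)  # (2, 16) (2, 5)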