import tensorflow as tf  # TensorFlow 1.x API (variable scopes)

def make_attention(target_embed, encoder_output, decoder_hidden, layer_idx):
    with tf.variable_scope("attention_layer_" + str(layer_idx)):
        embed_size = target_embed.get_shape().as_list()[-1]  # k
        # Project the decoder state down to the embedding size, add the target embedding,
        # and scale by sqrt(0.5) to keep the variance roughly constant.
        dec_hidden_proj = linear_mapping_weightnorm(decoder_hidden, embed_size, var_scope_name="linear_mapping_att_query")  # M*N1*k1 --> M*N1*k
        dec_rep = (dec_hidden_proj + target_embed) * tf.sqrt(0.5)

        encoder_output_a = encoder_output.outputs           # attention keys,   M*N2*k
        encoder_output_c = encoder_output.attention_values  # attention values, M*N2*k

        # Dot-product scores between every decoder position and every encoder position.
        att_score = tf.matmul(dec_rep, encoder_output_a, transpose_b=True)  # M*N1*k ** M*N2*k --> M*N1*N2
        att_score = tf.nn.softmax(att_score)

        # Attention-weighted sum of the values, rescaled by N2 * sqrt(1/N2) = sqrt(N2)
        # to counteract the variance change from summing over N2 positions.
        length = tf.cast(tf.shape(encoder_output_c), tf.float32)
        att_out = tf.matmul(att_score, encoder_output_c) * length[1] * tf.sqrt(1.0 / length[1])  # M*N1*N2 ** M*N2*k --> M*N1*k

        # Project back to the decoder hidden size so it can be combined with decoder_hidden.
        att_out = linear_mapping_weightnorm(att_out, decoder_hidden.get_shape().as_list()[-1], var_scope_name="linear_mapping_att_out")
    return att_out
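A minimal stand-alone sketch of how this layer might be exercised, assuming TensorFlow 1.x and with make_attention defined as above in the same module. The EncoderOutput namedtuple, all shapes, and the stub linear_mapping_weightnorm below are made up for illustration; the stub is a plain dense projection standing in for the repo's weight-normalized version.

import collections
import tensorflow as tf

EncoderOutput = collections.namedtuple("EncoderOutput", ["outputs", "attention_values"])

def linear_mapping_weightnorm(inputs, out_dim, var_scope_name="linear_mapping"):
    # Stand-in only: a plain dense projection; the real version applies weight normalization.
    with tf.variable_scope(var_scope_name):
        in_dim = inputs.get_shape().as_list()[-1]
        W = tf.get_variable("W", [in_dim, out_dim])
        b = tf.get_variable("b", [out_dim], initializer=tf.zeros_initializer())
        return tf.tensordot(inputs, W, axes=1) + b

# Dummy inputs: batch M=2, target length N1=5, source length N2=7, embed size k=16, decoder hidden size k1=32.
target_embed   = tf.random_normal([2, 5, 16])
decoder_hidden = tf.random_normal([2, 5, 32])
encoder_output = EncoderOutput(outputs=tf.random_normal([2, 7, 16]),
                               attention_values=tf.random_normal([2, 7, 16]))

att_out = make_attention(target_embed, encoder_output, decoder_hidden, layer_idx=0)
print(att_out.get_shape())  # (2, 5, 32): projected back to the decoder hidden size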