conv_encoder_utils.py source code

python

Project: conv_seq2seq  Author: tobyyouup
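import tensorflow as tf  # TensorFlow 1.x API (uses tf.variable_scope)

# linear_mapping_weightnorm is not shown in this snippet; it is assumed to be defined elsewhere in the project.
def make_attention(target_embed, encoder_output, decoder_hidden, layer_idx):
  # Shape notation (inferred): M = batch size, N1 = target length, N2 = source length,
  # k = embedding size, k1 = decoder hidden size.
  with tf.variable_scope("attention_layer_" + str(layer_idx)):
    embed_size = target_embed.get_shape().as_list()[-1]  # k
    # Project the decoder state to the embedding size: M*N1*k1 --> M*N1*k
    dec_hidden_proj = linear_mapping_weightnorm(decoder_hidden, embed_size, var_scope_name="linear_mapping_att_query")
    # Combine with the target embedding; sqrt(0.5) rescales the sum of the two terms
    dec_rep = (dec_hidden_proj + target_embed) * tf.sqrt(0.5)

    encoder_output_a = encoder_output.outputs           # M*N2*k, used for scoring
    encoder_output_c = encoder_output.attention_values  # M*N2*k, used as attention values

    # Dot-product scores: M*N1*k x (M*N2*k)^T --> M*N1*N2
    att_score = tf.matmul(dec_rep, encoder_output_a, transpose_b=True)
    att_score = tf.nn.softmax(att_score)  # normalize over source positions (last axis)

    length = tf.cast(tf.shape(encoder_output_c), tf.float32)
    # Weighted sum: M*N1*N2 x M*N2*k --> M*N1*k; length[1] * sqrt(1/length[1]) equals sqrt(N2)
    att_out = tf.matmul(att_score, encoder_output_c) * length[1] * tf.sqrt(1.0/length[1])

    # Project back to the decoder hidden size: M*N1*k --> M*N1*k1
    att_out = linear_mapping_weightnorm(att_out, decoder_hidden.get_shape().as_list()[-1], var_scope_name="linear_mapping_att_out")
  return att_out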
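
The core of make_attention (dot-product scores, softmax over source positions, weighted sum scaled by the square root of the source length) can be checked in isolation. The following is a minimal NumPy sketch, not the project's code: the function name attention_core, the random inputs, and the sizes M, N1, N2, K are assumptions made for illustration, and the two weight-normalized linear projections and the sqrt(0.5) residual mix are left out.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_core(dec_rep, enc_a, enc_c):
    """Hypothetical helper mirroring the matmul/softmax/scale steps only.

    dec_rep: (M, N1, K) decoder representation
    enc_a:   (M, N2, K) encoder outputs used for scoring
    enc_c:   (M, N2, K) encoder attention values
    """
    # Batched dot product: (M, N1, K) x (M, K, N2) -> (M, N1, N2)
    scores = np.matmul(dec_rep, enc_a.transpose(0, 2, 1))
    weights = softmax(scores, axis=-1)
    n2 = enc_c.shape[1]
    # n2 * sqrt(1 / n2) in the original simplifies to sqrt(n2).
    out = np.matmul(weights, enc_c) * np.sqrt(n2)  # (M, N1, K)
    return out, weights

# Illustrative sizes, chosen arbitrarily.
rng = np.random.default_rng(0)
M, N1, N2, K = 2, 3, 5, 4
dec_rep = rng.standard_normal((M, N1, K))
enc_a = rng.standard_normal((M, N2, K))
enc_c = rng.standard_normal((M, N2, K))

att_out, weights = attention_core(dec_rep, enc_a, enc_c)
print(att_out.shape)   # (2, 3, 4)
print(weights.shape)   # (2, 3, 5)
```

Each row of weights sums to 1 over the N2 source positions. As in the snippet above, no padding mask is applied, so padded source positions would also receive attention weight.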