def align(hid_align, h_dec, scope):
    """Compute additive attention of a decoder state over encoder positions.

    The score for each position is ``v_align . tanh(W h_dec + hid_align)``.
    The scores are softmax-normalised over positions and used to take a
    weighted sum of the encoder outputs held in the module-level ``h_t_lang``.

    Args:
        hid_align: Pre-projected encoder states that are added to the projected
            decoder state. Expected shape is
            (batch_size, sentence_length, dim_align) — TODO confirm against caller.
        h_dec: Decoder state passed to ``linear3``. Its projection is treated
            as (batch_size, dim_align).
        scope: Suffix used to name the projection and the ``v_align`` variable,
            so that separate attention heads get separate weights.

    Returns:
        A tuple ``(s_t, alpha)``:
        s_t: Context vector, the alpha-weighted sum of ``h_t_lang`` over the
            position axis. Presumably (batch_size, 2*y_enc_size); verify
            against ``h_t_lang``.
        alpha: Attention weights of shape (batch_size, sentence_length). Each
            row sums to 1.

    NOTE(review): relies on module globals ``linear3``, ``dim_align``,
    ``DO_SHARE`` and ``h_t_lang`` defined elsewhere in the file.
    """
    # Project the decoder state into the alignment space: (batch_size, dim_align).
    h_dec_align = linear3(h_dec, dim_align, "h_dec_align_" + scope)
    # Insert a length-1 position axis. Broadcasting then adds the same decoder
    # projection to every encoder position, so no explicit tile is needed.
    h_dec_align = tf.expand_dims(h_dec_align, 1)  # (batch_size, 1, dim_align)
    all_align = tf.tanh(h_dec_align + hid_align)

    with tf.variable_scope("v_align_" + scope, reuse=DO_SHARE):
        # The zero initialiser makes every score equal at the start of
        # training, so the initial attention is uniform.
        v_align = tf.get_variable(
            "v_align_" + scope,
            [dim_align],
            initializer=tf.constant_initializer(0.0),
        )

    # Dot product of each position's alignment vector with v_align.
    e_t = tf.reduce_sum(all_align * v_align, 2)  # (batch_size, sentence_length)
    # Normalise scores into attention weights over positions.
    alpha = tf.nn.softmax(e_t)

    # Weighted sum of the encoder outputs. The trailing size-1 axis of the
    # weights broadcasts across the feature axis of h_t_lang (replacing tf.tile).
    s_t = tf.reduce_sum(tf.expand_dims(alpha, 2) * h_t_lang, 1)
    return s_t, alpha
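# Stray text copied from the source web page (not code):
# "评论列表" ("Comment list"), "文章目录" ("Table of contents").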