def get_output_for(self, inputs, **kwargs):
    """Build the attention-pooled output for every passage time step.

    Each passage step is paired with each question step through
    ``tanh(p * q)``. The pairs are scored, the scores are normalised over
    the question axis (softmax, then masked and renormalised), and the
    paired features are averaged with those weights.

    Args:
        inputs: Sequence of exactly four symbolic tensors, in order:
            ``p_gru``   -- (batch, time_p, units)
            ``q_gru``   -- (batch, time_q, units)
            ``q_mask``  -- (batch, time_q); multiplied into the weights
            ``feature`` -- added to the (batch, time_p, time_q) scores;
            assumed to broadcast against that shape -- TODO confirm.
        **kwargs: Accepted but not used here.

    Returns:
        Symbolic tensor of shape (batch, time_p, units).
    """
    passage, question, question_mask, feature = inputs
    n_passage = passage.shape[1]
    n_question = question.shape[1]

    # Insert broadcast axes so the product covers every
    # (passage step, question step) pair.
    passage_4d = passage.dimshuffle(0, 1, 'x', 2)    # (batch, time_p, 1, units)
    question_4d = question.dimshuffle(0, 'x', 1, 2)  # (batch, 1, time_q, units)
    pair_feats = T.tanh(passage_4d * question_4d)    # (batch, time_p, time_q, units)
    # Fold batch and time_p together; reused below as the values being pooled.
    pair_feats = pair_feats.reshape((-1, n_question, self.units))  # (batch * time_p, time_q, units)

    # Pairwise score plus a question-only score shared by all passage steps.
    # NOTE(review): shapes of self.v1 / self.v2 are not visible here; the
    # reshape and squeeze() below assume they reduce the ``units`` axis.
    pair_score = T.dot(pair_feats, self.v1).reshape((-1, n_passage, n_question))  # (batch, time_p, time_q)
    question_score = T.dot(question, self.v2).squeeze()  # (batch, time_q)
    scores = pair_score + question_score.dimshuffle(0, 'x', 1) + feature  # (batch, time_p, time_q)

    # Softmax over time_q on the flattened rows, then apply the mask and
    # renormalise; the epsilon guards against a zero row sum.
    weights = T.nnet.softmax(scores.reshape((-1, n_question)))  # (batch * time_p, time_q)
    weights = weights.reshape((-1, n_passage, n_question)) * question_mask.dimshuffle(0, 'x', 1)
    weights = weights / (weights.sum(axis=2, keepdims=True) + 1e-8)  # (batch, time_p, time_q)
    weights = weights.reshape((-1, n_question))  # (batch * time_p, time_q)

    # Weighted sum of the paired features for each (batch, passage step) row.
    pooled = T.batched_dot(weights, pair_feats)  # (batch * time_p, units)
    return pooled.reshape((-1, n_passage, self.units))
评论列表
文章目录