def __call__(self, a_list, state, batch_size, xp):
    """Return an attention context vector over the annotations in ``a_list``.

    Each annotation is scored against ``state['h2']`` via a transposed
    batched matmul, the score is clamped to [-40, 40] and exponentiated,
    and the annotations are then summed with weights ``exp(score) / total``.

    Args:
        a_list: Sequence of annotation tensors, one per position.
            NOTE(review): presumably each is (batch_size, self.hidden_size)
            so the weighted product reshapes to that size -- confirm.
        state: Mapping; only its ``'h2'`` entry (the query) is read.
        batch_size: Leading dimension used for every reshape/zeros call.
        xp: Array module providing ``zeros``, ``clip`` and ``float32``
            (presumably numpy or cupy).

    Returns:
        Tuple ``(context, e_list, sum_e)``:
        ``context`` -- the weighted sum, reshaped to
        (batch_size, self.hidden_size);
        ``e_list`` -- the per-position ``exp(score)`` values, in order
        (see the normalization note in the second loop);
        ``sum_e`` -- the running total of those values, shape
        (batch_size, 1).
        If ``a_list`` is empty, ``context`` and ``sum_e`` are all-zero
        float32 arrays and ``e_list`` is empty.
    """
    score_shape = (batch_size, 1)

    # Pass 1: unnormalized attention weights exp(score) and their total.
    weights = []
    total = xp.zeros(score_shape, dtype=xp.float32)
    for annotation in a_list:
        raw = batch_matmul(state['h2'], annotation, transa=True)
        score = reshape(raw, score_shape)
        # Clamp the raw score so exp() below cannot blow up. Writing to
        # .data directly presumably bypasses the autograd graph
        # (Chainer-style Variable), i.e. the gradient is not clipped.
        score.data = xp.clip(score.data, -40, 40)
        weight = exp(score)
        weights.append(weight)
        total = total + weight

    # Pass 2: accumulate the weighted annotations into the context vector.
    context_shape = (batch_size, self.hidden_size)
    context = xp.zeros(context_shape, dtype=xp.float32)
    for annotation, weight in zip(a_list, weights):
        # NOTE(review): for types without in-place division (Chainer
        # Variables appear to be one) this only rebinds ``weight``; the
        # entries already stored in ``weights`` -- and so the returned
        # e_list -- stay unnormalized. Callers can divide by ``sum_e``.
        weight /= total
        weighted = batch_matmul(annotation, weight)
        context = context + reshape(weighted, context_shape)
    return context, weights, total
# (Web-page residue, not code: "评论列表" = "Comment list",
#  "文章目录" = "Table of contents".)