def __call__(self, ht, xs, d_bar_s_1):
    """Run one decoding step of an attention / pointer mechanism.

    The equation numbers (2)-(4) in the comments below are the original
    author's references to an external paper that is not included here.

    NOTE(review): this snippet is incomplete as written. ``W1``, ``W2``,
    ``W3``, ``rnn`` and ``y_s_1`` are not defined in this method. The layers
    are presumably ``self.W1`` ... ``self.rnn`` registered on the enclosing
    Chainer link -- confirm. ``xs`` is never read; perhaps ``y_s_1`` was
    meant to come from it -- confirm.

    Args:
        ht: Encoder hidden states. The original comment gives the layout
            as batch_size * n_words * in_size (the rest of its description
            was lost to a character-encoding error).
        xs: Purpose unknown; its original comment was lost to the same
            encoding error.
        d_bar_s_1: State returned as ``d_bar_s`` by the previous step, or
            ``None`` on the first step. In that case a float32 zero vector
            of length ``self.in_size`` is used.

    Returns:
        Tuple ``(d_bar_s, d_s, c_s, z_s)``:
            - ``d_bar_s``: the new state, ``relu(W3(concat([c_s, d_s])))``.
            - ``d_s``: the ``rnn`` output for this step.
            - ``c_s``: the attention-weighted sum of ``ht``.
            - ``z_s``: the argmax of the attention weights.
    """
    # Identity test, not `== None`. On every step after the first, this
    # argument is the Variable returned by F.relu below, and Chainer
    # Variables do not support `==`. For a multi-element NumPy array, `==`
    # compares elementwise and the `if` would raise.
    if d_bar_s_1 is None:
        # float32 to match Chainer's default parameter dtype (np.zeros
        # defaults to float64).
        # NOTE(review): 1-D with no batch axis -- confirm the shape `rnn`
        # expects for its initial state.
        d_bar_s_1 = np.zeros(self.in_size, dtype=np.float32)

    # Transpose each element along ht's first axis, then project it with W1.
    ht_T = [F.transpose(h) for h in ht]
    phi_ht = [W1(h) for h in ht_T]

    # RNN step fed with the previous state and y_s_1, then projected with W2.
    d_s = rnn(d_bar_s_1, y_s_1)
    phi_d = F.transpose_sequence(W2(F.transpose_sequence(d_s)))

    # (4) Scores: elementwise product of the projected decoder state with
    # each projected encoder state.
    u_st = [phi_d * p for p in phi_ht]

    # (3) Normalize the scores by their total.
    # NOTE(review): if eq. (3) is a softmax, this needs an exp (F.softmax).
    # Dividing raw scores allows negative weights and a zero denominator.
    # F.sum / F.argmax are also handed Python lists here, whereas Chainer
    # functions expect a Variable or array, so an F.stack is probably
    # missing -- confirm.
    sum_u = F.sum(u_st)
    alpha_st = [u / sum_u for u in u_st]
    z_s = F.argmax(alpha_st, axis=0)

    # (2) Context: sum of the encoder states weighted by the attention weights.
    c_s = F.sum([a * h for a, h in zip(alpha_st, ht)])

    d_bar_s = F.relu(W3(F.concat([c_s, d_s])))
    return d_bar_s, d_s, c_s, z_s
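# 评论列表 ("Comment list") and 文章目录 ("Table of contents"): these appear
# to be navigation labels left over from the web page this snippet was copied from.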