def get_output_for(self, inputs, **kwargs):
    """Compute output probabilities using sentinel-gated (adaptive) attention.

    Parameters
    ----------
    inputs : sequence
        ``[s_hat_t, h_hat_t, H]``.  As used below, ``s_hat_t`` and
        ``h_hat_t`` are 2-D ``(batch, dim)`` and ``H`` is 3-D
        ``(batch, k, dim)`` with the k candidate regions on axis 1 (the
        sentinel ``s_hat_t`` is appended to ``H`` along that axis).
        NOTE(review): the feature dim of ``H``/``s_hat_t`` must equal the
        dim of ``h_hat_t`` for ``c_hat_t + h_hat_t`` -- confirm with caller.
    **kwargs
        Unused; presumably accepted to match Lasagne's ``get_output_for``
        layer interface.

    Returns
    -------
    Symbolic ``softmax((c_hat_t + h_hat_t) . W_p)``, one row per batch item.
    """
    s_hat_t = inputs[0]
    h_hat_t = inputs[1]
    H = inputs[2]

    # Hidden-state projection into the attention space, (batch, att).
    # Shared by the region scores and the sentinel score.
    h_proj = T.dot(h_hat_t, self.W_g_to_attenGate)

    # Region scores z_t, (batch, k).  T.dot(H, W_v) is (batch, k, att), so
    # the hidden-state term must vary along the attention axis and be
    # broadcast over the k regions.  (The previous ones-matrix trick built
    # a (batch, att, k) tensor, which only matched when att == k and then
    # added the gate term transposed.)
    zt = T.dot(
        self.nonlinearity(
            T.dot(H, self.W_v_to_attenGate) + h_proj.dimshuffle(0, 'x', 1)
        ),
        self.W_h_to_attenGate,
    )[:, :, 0]

    # Sentinel score; (batch, 1) assuming W_h_to_attenGate is (att, 1).
    vt = T.dot(
        self.nonlinearity(T.dot(s_hat_t, self.W_s_to_attenGate) + h_proj),
        self.W_h_to_attenGate,
    )

    # Attention weights over the k regions plus the sentinel.
    # NOTE(review): nonlinearity_atten is presumably a softmax -- confirm.
    alpha_hat_t = self.nonlinearity_atten(T.concatenate([zt, vt], axis=-1))

    # Candidate features with the sentinel appended as an extra "region":
    # (batch, k + 1, dim).
    feature = T.concatenate([H, s_hat_t.dimshuffle(0, 'x', 1)], axis=1)

    # Attention-weighted context vector, (batch, dim).  Equivalent to the
    # former (dim, batch, k+1) layout summed over the last axis and
    # transposed, without the transpose round-trip.
    c_hat_t = T.sum(alpha_hat_t.dimshuffle(0, 1, 'x') * feature, axis=1)

    out = T.dot(c_hat_t + h_hat_t, self.W_p)
    return nonlinearities.softmax(out)
评论列表
文章目录