def _encode(self, x_list):
    """Encode a time-major batch of source inputs with a bidirectional LSTM.

    Args:
        x_list: sequence with one element per source time step; element
            ``t`` holds the batch of inputs for step ``t`` (its ``len`` is
            taken as the batch size). Each element is wrapped by
            ``_mkivar`` and projected by ``self.x_i`` before the LSTMs.

    Returns:
        A 6-tuple ``(fb_mat, fbe_mat, fc, bc, f_last, b_last)``:
            fb_mat: forward and backward hidden states joined on axis 2,
                shape [batch, srclen, 2 * hidden].
            fbe_mat: ``self.fb_e`` applied to ``fb_mat`` flattened to
                [batch * srclen, 2 * hidden]; presumably
                [batch * srclen, atten] on output (depends on ``self.fb_e``).
            fc: final cell state of the forward LSTM.
            bc: final cell state of the backward LSTM.
            f_last: forward hidden state after the last time step.
            b_last: backward hidden state aligned with time step 0, i.e.
                the last state the backward pass produced.

    Raises:
        IndexError: if ``x_list`` is empty (``x_list[0]`` is read).
    """
    src_len = len(x_list)
    n_batch = len(x_list[0])

    # One zero tensor shared as the initial cell/hidden state of both
    # directions; F.lstm returns fresh variables, so nothing mutates it.
    zero_state = _zeros((n_batch, self.hidden_size))

    embedded = [self.x_i(_mkivar(step)) for step in x_list]

    # Forward pass: left to right over the source steps.
    fwd_cell = fwd_h = zero_state
    fwd_states = []
    for emb in embedded:
        fwd_cell, fwd_h = F.lstm(fwd_cell, self.i_f(emb) + self.f_f(fwd_h))
        fwd_states.append(fwd_h)

    # Backward pass: right to left. The collected states are flipped
    # afterwards so that bwd_states[t] lines up with source step t.
    bwd_cell = bwd_h = zero_state
    bwd_steps = []
    for emb in reversed(embedded):
        bwd_cell, bwd_h = F.lstm(bwd_cell, self.i_b(emb) + self.b_b(bwd_h))
        bwd_steps.append(bwd_h)
    bwd_states = bwd_steps[::-1]

    # Stack per-step states along a new time axis:
    # [batch, hidden] -> [batch, 1, hidden] -> [batch, srclen, hidden].
    fwd_mat = F.concat([F.expand_dims(h, 1) for h in fwd_states], 1)
    bwd_mat = F.concat([F.expand_dims(h, 1) for h in bwd_states], 1)

    # [batch, srclen, 2 * hidden]
    fb_mat = F.concat([fwd_mat, bwd_mat], 2)

    # Flatten batch and time so self.fb_e sees one row per (batch, step).
    flat_fb = F.reshape(fb_mat, [n_batch * src_len, 2 * self.hidden_size])
    fbe_mat = self.fb_e(flat_fb)

    return fb_mat, fbe_mat, fwd_cell, bwd_cell, fwd_states[-1], bwd_states[0]
评论列表
文章目录