def encode(self, x_input, x_query, answer):
    """Run one memory-network hop and return its classification loss.

    Despite its name, this method does not return an encoding: it goes all
    the way from raw inputs to the loss.

    Steps, in the order they are performed below:
      1. Embed ``x_input`` with ``self.encode_input`` and ``x_query`` with
         ``self.encode_query``.
      2. Match the input embedding against the query embedding
         (``functions.matmul`` with ``transb=True``) and normalise the
         scores with ``functions.softmax``.
      3. Embed ``x_input`` a second time with ``self.encode_output`` and
         combine it with the normalised scores via ``functions.matmul``
         (wrapped in two ``functions.swapaxes(..., 1, 0)`` calls).
      4. Add that read-out to the query embedding, project with ``self.W``,
         and score against ``answer`` with
         ``functions.softmax_cross_entropy``.

    Args:
        x_input: Raw memory inputs; fed to both ``encode_input`` and
            ``encode_output``.
        x_query: Raw query; fed to ``encode_query``.
        answer: Target labels for ``functions.softmax_cross_entropy``.

    Returns:
        Whatever ``functions.softmax_cross_entropy`` returns for the
        projected logits and ``answer``.

    NOTE(review): no tensor shapes are fixed inside this method. Earlier
    debug comments claimed the weighted read was (2, 50, 1) and became
    (2, 50) after ``swapaxes(..., 1, 0)``, but ``swapaxes`` cannot drop an
    axis. A debug line also used ``swapaxes(c, 2, 1)`` rather than
    ``(1, 0)``. Confirm which axis order is actually intended.
    NOTE(review): ``functions.softmax`` normalises along axis 1 by default;
    presumably that is the memory axis of the score tensor — TODO confirm.
    """
    # Input-side and query-side embeddings.
    mem_in = self.encode_input(x_input)
    query = self.encode_query(x_query)

    # Match the input embedding against the query, then normalise.
    weights = functions.softmax(functions.matmul(mem_in, query, transb=True))

    # Output-side embedding of the same inputs, read out with the weights.
    mem_out = self.encode_output(x_input)
    readout = functions.swapaxes(
        functions.matmul(functions.swapaxes(mem_out, 1, 0), weights), 1, 0)

    # Combine with the query, project to answer logits, and score.
    return functions.softmax_cross_entropy(self.W(query + readout), answer)
# Scraped web-page residue, not code: "评论列表" = "Comment list",
# "文章目录" = "Table of contents".