def __call__(self, e1, e2):
    """Apply the bilinear transformation to a pair of inputs.

    Computes ``flatten(outer(e1, e2)) @ W' + e1 @ V1 + e2 @ V2 + b``,
    where ``W'`` is ``self.W`` reshaped to
    ``(in_size1 * in_size2, out_size)``.

    Args:
        e1: First input. Presumably shaped (batch, in_size1): it is indexed
            as a 2-D array, and its outer product with ``e2`` is reshaped
            to ``in_size1 * in_size2`` columns. TODO confirm.
        e2: Second input. Presumably shaped (batch, in_size2).

    Returns:
        The sum of the bilinear term and the two linear terms, plus
        ``self.b`` broadcast against that sum.
    """
    pair_dim = self.in_size1 * self.in_size2

    # Per-example outer product. For 2-D inputs this is
    # (B, n1, 1) x (B, 1, n2) -> (B, n1, n2). It is then flattened so the
    # bilinear form becomes a single matrix product.
    outer = F.batch_matmul(e1[:, :, None], e2[:, None, :])
    outer_flat = F.reshape(outer, (-1, pair_dim))

    # View W as a (pair_dim, out_size) matrix. It is presumably stored as
    # (in_size1, in_size2, out_size); confirm against the initializer.
    w_flat = F.reshape(self.W, (pair_dim, self.out_size))
    bilinear_term = F.matmul(outer_flat, w_flat)

    # Add the linear terms left to right: (bilinear + e1-term) + e2-term.
    total = bilinear_term + F.matmul(e1, self.V1)
    total = total + F.matmul(e2, self.V2)

    # F.broadcast yields both operands expanded to a common shape, so the
    # bias can be added elementwise.
    total, bias = F.broadcast(total, self.b)
    return total + bias
评论列表
文章目录