def __call__(self, x1, x2):
    """Apply the biaffine transformation to every pair of positions.

    Args:
        x1: Left input, unpacked as ``(batch_size, len1, dim1)``.
        x2: Right input, unpacked as ``(batch_size, len2, dim2)``. Its
            batch dimension is assumed to equal ``x1``'s; this is not
            checked here.

    Returns:
        Variable of shape ``(batch_size, len1, len2, out_size)``.

    ``self.nobias`` is indexed at positions 0-2:

    - When ``nobias[0]`` is false, a constant column of ones is appended
      to ``x1`` along its last axis.
    - When ``nobias[1]`` is false, the same is done for ``x2``.
    - When ``nobias[2]`` is false, ``self.b`` is broadcast to the output
      shape and added.
    """
    xp = self.xp
    out_size = self.out_size

    batch_size, len1, dim1 = x1.shape
    if not self.nobias[0]:
        # Use the input's dtype rather than a hard-coded float32 so that
        # non-float32 inputs do not trigger a dtype mismatch in F.concat.
        ones1 = xp.ones((batch_size, len1, 1), dtype=x1.dtype)
        x1 = F.concat((x1, ones1), axis=2)
        dim1 += 1

    len2, dim2 = x2.shape[1:]
    if not self.nobias[1]:
        ones2 = xp.ones((batch_size, len2, 1), dtype=x2.dtype)
        x2 = F.concat((x2, ones2), axis=2)
        dim2 += 1

    # Contract x1 with W in a single 2-D matmul. The transpose/reshape
    # below treats W as (dim1, dim2, out_size) -- presumably its declared
    # shape; TODO confirm against the layer's __init__ -- rearranged to
    # (dim1, out_size * dim2).
    x1_flat = F.reshape(x1, (batch_size * len1, dim1))
    W_flat = F.reshape(F.transpose(self.W, (0, 2, 1)),
                       (dim1, out_size * dim2))
    affine = F.reshape(F.matmul(x1_flat, W_flat),
                       (batch_size, len1 * out_size, dim2))

    # (batch, len1 * out_size, dim2) x (batch, len2, dim2)^T
    #   -> (batch, len1 * out_size, len2)
    scores = batch_matmul(affine, x2, transb=True)

    # Split the fused axis back out and move out_size last:
    # (batch, len1, out_size, len2) -> (batch, len1, len2, out_size).
    biaffine = F.transpose(
        F.reshape(scores, (batch_size, len1, out_size, len2)),
        (0, 1, 3, 2))

    if not self.nobias[2]:
        biaffine += F.broadcast_to(self.b, biaffine.shape)
    return biaffine
评论列表
文章目录