def __call__(self, x):
    """Append batch-wise similarity features to ``x``.

    ``x`` is projected with ``self.T`` and reshaped to
    ``(batch, num_kernels, ndim_kernel)``. For every pair of samples the L1
    distance between their kernel rows is computed, mapped to a similarity
    with ``exp(-distance)`` and summed over the batch (a sample's similarity
    to itself is suppressed). The resulting ``(batch, num_kernels)``
    features are concatenated to ``x`` along axis 1, so ``x`` must be 2-D.
    NOTE(review): this looks like a minibatch-discrimination layer --
    confirm against the enclosing class.

    When ``self.train_weights == False`` the weights of ``self.T`` are kept
    fixed: the first call stores a snapshot in ``self.initial_T`` (unless
    one is already present) and every later call restores ``self.T.W.data``
    from it before the forward pass.

    Args:
        x: Input variable whose first axis is the batch axis.

    Returns:
        ``x`` with the similarity features appended along axis 1.
    """
    xp = chainer.cuda.get_array_module(x.data)
    batchsize = x.shape[0]

    # Equality (not truthiness) is kept on purpose so that non-bool values
    # such as None are treated exactly as they were before.
    freeze_weights = self.train_weights == False  # noqa: E712

    if freeze_weights and self.initial_T is not None:
        # Restore from a *copy*: binding the snapshot array itself would let
        # later in-place updates of W.data silently modify the snapshot too,
        # which defeats the purpose of freezing the weights.
        self.T.W.data = self.initial_T.copy()

    M = F.reshape(self.T(x), (-1, self.num_kernels, self.ndim_kernel))
    M = F.expand_dims(M, 3)               # (B, K, D, 1)
    M_T = F.transpose(M, (3, 1, 2, 0))    # (1, K, D, B)
    M, M_T = F.broadcast(M, M_T)          # both (B, K, D, B)
    norm = F.sum(abs(M - M_T), axis=2)    # pairwise L1 distance, (B, K, B)

    # Identity mask over the two batch axes; adding 1e6 on the diagonal
    # drives exp(-...) to zero there, removing each sample's self-match.
    eraser = F.broadcast_to(
        xp.eye(batchsize, dtype=x.dtype).reshape((batchsize, 1, batchsize)),
        norm.shape,
    )
    c_b = F.exp(-(norm + 1e6 * eraser))
    o_b = F.sum(c_b, axis=2)              # (B, K)

    if freeze_weights and self.initial_T is None:
        # First call with frozen weights: take an independent snapshot.
        # Later calls restore from it above, so it never needs refreshing.
        self.initial_T = self.T.W.data.copy()

    return F.concat((x, o_b), axis=1)
评论列表
文章目录