def call(self, x, mask=None):
    """Select the entries of ``x`` at positions ``self.indices`` along axis 2.

    Args:
        x: Input tensor. On the TensorFlow path it must be rank 3, because
            both permutations handed to ``tf.transpose`` have length 3.
        mask: Accepted but ignored by this method.

    Returns:
        A tensor whose axis 2 contains ``x``'s entries at ``self.indices``
        (in the order given); axes 0 and 1 are left as they were.
    """
    # NOTE(review): self.indices is presumably a sequence of ints set up
    # elsewhere on the layer -- confirm against the class definition.
    if K.backend() != 'tensorflow':
        # Other backends: index axis 2 directly with the index list.
        return x[:, :, self.indices]

    # tf.gather is called without an axis argument, so it selects along
    # axis 0: rotate axis 2 to the front, gather, then rotate it back.
    channels_first = tf.transpose(x, perm=(2, 0, 1))
    picked = tf.gather(channels_first, self.indices)
    return tf.transpose(picked, perm=(1, 2, 0))
# 文章目录 ("Table of contents") -- stray heading from the page this snippet was copied from