def batched_gram5d(self, fmap):
    """Return normalised Gram matrices for a stack of feature maps.

    Args:
        fmap: Symbolic tensor; per the shape comments below it is expected
            to be 5-D, laid out as (layer, batch, featuremaps, height,
            width) -- confirm against the caller.

    Returns:
        Symbolic tensor of shape (layer, batch, featuremaps, featuremaps):
        for every (layer, batch) entry, the product of the flattened
        feature maps with their own transpose, divided by
        featuremaps * height * width.
    """
    # (layer, batch, featuremaps, height*width)
    fmap = fmap.flatten(ndim=4)
    shape = fmap.shape
    n_maps, n_pixels = shape[2], shape[3]

    # (layer*batch, featuremaps, height*width): batched_dot wants a single
    # leading batch axis, so the layer and batch axes are merged.
    stacked = fmap.reshape((-1, n_maps, n_pixels))

    # (layer*batch, featuremaps, featuremaps)
    gram = T.batched_dot(stacked, stacked.dimshuffle(0, 2, 1))

    # Split the merged axis back into (layer, batch). The trailing two axes
    # are now featuremaps x featuremaps, so the flattened input shape cannot
    # be reused here: it only has the right element count when
    # featuremaps == height*width.
    gram = gram.reshape((shape[0], shape[1], n_maps, n_maps))

    # The T.prod term can't be taken outside as a T.mean in style_loss(),
    # since the width and height of the image might vary.
    return gram / T.prod(shape[-2:])
# (web page residue) Comment list
# (web page residue) Article table of contents