def batched_gram(self, fmap):
    """Compute the batched Gram matrix of a stack of feature maps.

    Parameters
    ----------
    fmap : symbolic 4-D tensor
        Feature maps, assumed (batch, featuremaps, height, width) —
        flattened here to (batch, featuremaps, height*width).

    Returns
    -------
    Symbolic 3-D tensor of shape (batch, featuremaps, featuremaps):
    per-sample Gram matrices, normalized according to ``self.net_type``.

    Raises
    ------
    ValueError
        If ``self.net_type`` is not 0 or 1.
    """
    # (batch, featuremaps, height*width)
    fmap = fmap.flatten(ndim=3)
    # Hoisted: identical in both normalization branches.
    gram = T.batched_dot(fmap, fmap.dimshuffle(0, 2, 1))
    # The normalization term can't be taken outside as a T.mean in
    # style_loss(), since the width and height of the image might vary.
    if self.net_type == 0:
        # Normalize by featuremaps * (height*width).
        return gram / T.prod(fmap.shape[-2:])
    elif self.net_type == 1:
        # Normalize by height*width only.
        return gram / T.prod(fmap.shape[-1])
    # Previously fell through and silently returned None; fail loudly.
    raise ValueError("Unsupported net_type: {}".format(self.net_type))
# (removed non-code web-scrape artifacts: "评论列表" / "文章目录")