def forward(self, color, sketch):
    """Fuse encoded sketch features with a pooled color hint and project them.

    Args:
        color: Color input. It is average-pooled with a 16x16 kernel and
            stride 16 before fusion. NOTE(review): presumably an NCHW image
            tensor; confirm with callers.
        sketch: Input encoded by ``self.model``.

    Returns:
        ``self.out`` applied to the ``self.prototype`` output, flattened to
        ``(batch, -1)``. The batch size is taken from ``color``.
    """
    # Pooling keeps dim 0, so reading the batch size before pooling gives the
    # same value the original read afterwards.
    batch = color.size(0)
    pooled_hint = F.avg_pool2d(color, kernel_size=16, stride=16)
    sketch_feat = self.model(sketch)
    # Concatenate along dim 1 (presumably channels), sketch features first.
    fused = torch.cat([sketch_feat, pooled_hint], dim=1)
    proto = self.prototype(fused)
    return self.out(proto.view(batch, -1))