def forward(self, y, z):
    """Combine a low-resolution stream ``y`` with a higher-resolution stream ``z``.

    ``z`` is max-pooled by ``self.scale`` and concatenated with ``y`` along
    dim 1. The result goes through ``self.conv1`` then ``self.conv2`` to
    produce ``y_prime``. ``y_prime`` is projected by ``self.conv_res``,
    nearest-upsampled to ``z``'s spatial size, and added to ``z`` to produce
    ``z_prime``.

    Args:
        y: Tensor whose non-channel dims must match the pooled ``z`` (required
            by the ``torch.cat`` on dim 1). Presumably (N, C_y, H', W') — confirm.
        z: Tensor to pool and to add the residual to. Presumably (N, C_z, H, W)
            — confirm.

    Returns:
        Tuple ``(y_prime, z_prime)``: the conv1/conv2 output, and ``z`` plus
        the upsampled ``conv_res`` projection.
    """
    # Functional form: same result as nn.MaxPool2d(scale, scale)(z) but
    # without constructing a new module on every forward call.
    z_pooled = F.max_pool2d(z, kernel_size=self.scale, stride=self.scale)
    x = torch.cat([y, z_pooled], dim=1)

    y_prime = self.conv1(x)
    y_prime = self.conv2(y_prime)

    x = self.conv_res(y_prime)
    # Upsample to z's exact spatial size instead of y_prime's size * scale.
    # The two agree when H and W are divisible by scale. When they are not,
    # pooling floors the size, the scaled size comes out too small, and
    # ``z + x`` would fail with a shape mismatch.
    # F.interpolate is the non-deprecated replacement for F.upsample.
    x = F.interpolate(x, size=z.shape[-2:], mode="nearest")
    z_prime = z + x

    return y_prime, z_prime
评论列表
文章目录