def forward(self, x):
    """Run the network on ``x`` and return the border-cropped sigmoid output.

    ``self.blocks`` and ``self.scales`` are walked in lockstep:

    * For the first ``self.n_layers`` stages, the block output is
      concatenated with the block's own input, the result is stored as a
      skip connection, and the stage's ``scale`` op is applied.
    * Stage ``self.n_layers`` applies only its block; its ``scale`` entry is
      not used.
    * Every later stage applies ``scale`` first, concatenates the result
      with the skip stored at index ``2 * self.n_layers - i``, and then
      applies the block. Skips are therefore consumed from the most
      recently stored one back towards the first.

    NOTE(review): this looks like a U-Net-style encoder/bottleneck/decoder,
    with ``scale`` presumably downsampling in the first half and upsampling
    in the second -- confirm against where ``blocks``/``scales`` are built.

    Args:
        x: Input tensor. The final crop indexes four axes, so the output of
            ``self.output_conv`` must have at least four dimensions
            (presumably batch, channels, height, width -- TODO confirm).

    Returns:
        The sigmoid of the output convolution, with ``self.hps.patch_border``
        elements removed from both ends of the third and fourth axes. A
        border of ``0`` leaves the output uncropped.
    """
    # Input
    x = self.input_conv(x)

    # Network
    skips = []
    for i, (block, scale) in enumerate(zip(self.blocks, self.scales)):
        if i < self.n_layers:
            x = concat([block(x), x])
            # Store the pre-scale activation for the mirrored later stage.
            skips.append(x)
            x = scale(x)
        elif i == self.n_layers:
            x = block(x)
        else:
            # Stage n_layers + k pairs with the k-th most recent skip.
            x = block(concat([scale(x), skips[2 * self.n_layers - i]]))

    # Output
    x = self.output_conv(x)
    b = self.hps.patch_border
    # A slice of ``b:-b`` with b == 0 is ``0:0`` and would yield an empty
    # tensor, so only crop when there is a border to remove.
    if b != 0:
        x = x[:, :, b:-b, b:-b]
    return F.sigmoid(x)
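# (Web-page residue, not part of the code: "Comment list")
# (Web-page residue, not part of the code: "Article table of contents")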