def forward(self, x):
    """Run the down path, then the up path with skip connections, and
    return the sigmoid of the final convolution with the border cropped.

    Each module in ``self.down`` is applied in order. The first one sees
    ``x`` directly; the second sees ``self.pool_top`` of the previous
    stage's output; all later ones see ``self.pool`` of it. Every stage
    output is passed through ``self.dropout2d`` and kept as a skip tensor.

    The modules in ``self.up`` are then applied in reverse order. Each one
    receives the current tensor, upsampled, concatenated along dim 1 with
    the skip tensor of the matching down stage. Index 0 (applied last) uses
    ``self.upsample_top``; every other index uses ``self.upsample``.

    Args:
        x: Input tensor. The final slicing indexes four dimensions, so the
            output of ``self.conv_final`` must be 4-D (presumably
            ``(N, C, H, W)`` -- confirm against the caller).

    Returns:
        ``sigmoid(conv_final(...))`` with ``self.hps.patch_border``
        elements removed from both ends of the last two dimensions. When
        the border is 0 the tensor is returned uncropped.
    """
    skips = []
    for i, down in enumerate(self.down):
        if i == 0:
            x_in = x
        elif i == 1:
            x_in = self.pool_top(skips[-1])
        else:
            x_in = self.pool(skips[-1])
        skips.append(self.dropout2d(down(x_in)))

    # The deepest stage output seeds the up path; it is not used as a skip.
    x_out = skips[-1]
    # zip() pairs skip i with up-module i; walk the pairs from deepest to
    # shallowest.
    for i, (x_skip, up) in reversed(list(enumerate(zip(skips[:-1], self.up)))):
        upsample = self.upsample_top if i == 0 else self.upsample
        x_out = up(torch.cat([upsample(x_out), x_skip], 1))
        x_out = self.dropout2d(x_out)

    x_out = self.conv_final(x_out)

    b = self.hps.patch_border
    if b:
        # Only crop for a non-zero border: with b == 0 the slice ``0:-0``
        # is ``0:0`` and would yield an empty tensor.
        x_out = x_out[:, :, b:-b, b:-b]
    # torch.sigmoid is the same op as F.sigmoid, without the deprecation
    # warning that some PyTorch releases emit for the functional alias.
    return torch.sigmoid(x_out)
评论列表
文章目录