def decoder(self, z, sc_feat32, sc_feat16, sc_feat8, sc_feat4):
    """Decode latent ``z`` back to an output tensor using encoder skip features.

    The latent is reshaped to ``(-1, self.hidden_size, 1, 1)``. It then goes
    through four identical stages:

    1. upsample;
    2. concatenate a skip feature along dim 1;
    3. conv, then ReLU;
    4. the ``dec_bn*`` layer.

    A final upsample, conv and ``tanh`` follow.

    Args:
        z: Latent tensor. Its element count must be a multiple of
            ``self.hidden_size``. It may be non-contiguous.
        sc_feat32, sc_feat16, sc_feat8, sc_feat4: Features concatenated onto
            the decoder activations. They are consumed in the order
            ``sc_feat4``, ``sc_feat8``, ``sc_feat16``, ``sc_feat32``. Each must
            match the preceding upsample output in every dim except dim 1,
            because ``torch.cat`` requires this. The numeric suffixes presumably
            denote spatial resolution — TODO confirm against the encoder.

    Returns:
        The output of ``torch.tanh``, so every value lies in [-1, 1].
    """
    # reshape (not view): also accepts non-contiguous latents, and still returns
    # a view whenever the memory layout allows one.
    x = z.reshape(-1, self.hidden_size, 1, 1)

    # One entry per stage: (upsample, skip feature, conv, norm), in forward order.
    stages = (
        (self.dec_upsamp1, sc_feat4, self.dec_conv1, self.dec_bn1),
        (self.dec_upsamp2, sc_feat8, self.dec_conv2, self.dec_bn2),
        (self.dec_upsamp3, sc_feat16, self.dec_conv3, self.dec_bn3),
        (self.dec_upsamp4, sc_feat32, self.dec_conv4, self.dec_bn4),
    )
    for upsample, skip, conv, norm in stages:
        x = torch.cat([upsample(x), skip], 1)
        # NOTE(review): the norm layer runs AFTER the ReLU (post-activation).
        # This order is kept so existing trained weights behave identically.
        x = norm(F.relu(conv(x)))

    # Output head: no skip feature and no norm layer. torch.tanh is used
    # directly. F.tanh only forwards to it, and older PyTorch releases flag
    # F.tanh as deprecated.
    x = self.dec_upsamp5(x)
    return torch.tanh(self.dec_conv5(x))
# Define the forward pass.
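# 评论列表 ("comment list") — stray web-page text from the original paste
# 文章目录 ("table of contents") — stray web-page text from the original paste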