def forward(self, x):
    """Run the backbone on ``x`` and assemble the P3-P7 feature pyramid.

    Args:
        x: Passed unchanged to ``self.resnet``. Presumably an image batch of
            shape (N, C, H, W) -- TODO confirm against the caller.

    Returns:
        A 5-tuple ``(p3, p4, p5, p6, p7)`` of feature maps.

    Notes:
        * ``self.resnet(x)`` must yield exactly four feature maps. The first
          one (the stage-2 feature) is discarded because it is too large.
        * P6 is computed from the stage-5 backbone feature. P7 is computed
          from ReLU(P6).
        * P5 is only the lateral projection of the stage-5 feature. It
          receives no top-down input and no post-merge transform.
        * P4 and P3 each add the next-coarser level (passed through
          ``self._upsample``) to their own lateral projection, then apply
          ``upsample_transform_1`` / ``upsample_transform_2`` respectively.
          ``_upsample`` presumably resizes its first argument to the spatial
          size of its second -- verify in ``_upsample``.
    """
    # The leading backbone output (stage 2) is dropped: it is too large to use.
    _, c3, c4, c5 = self.resnet(x)

    # Extra coarse levels, both rooted at the stage-5 backbone feature.
    p6 = self.pyramid_transformation_6(c5)
    p7 = self.pyramid_transformation_7(F.relu(p6))

    # Top of the top-down pathway: lateral projection only.
    p5 = self.pyramid_transformation_5(c5)

    # Level 4: upsampled P5 merged with the C4 lateral projection, then transformed.
    lateral_4 = self.pyramid_transformation_4(c4)
    p4 = self.upsample_transform_1(
        torch.add(self._upsample(p5, lateral_4), lateral_4)
    )

    # Level 3: same pattern, fed by the *transformed* P4 from above.
    # NOTE(review): reference FPN implementations (e.g. torchvision's
    # FeaturePyramidNetwork) propagate the pre-transform sum downward instead;
    # confirm this deviation is intended before relying on paper parity.
    lateral_3 = self.pyramid_transformation_3(c3)
    p3 = self.upsample_transform_2(
        torch.add(self._upsample(p4, lateral_3), lateral_3)
    )

    return p3, p4, p5, p6, p7
评论列表
文章目录