def forward(self, x):
    """Run the two-stream (FRRN-style) network on ``x`` and return its output.

    Pipeline, exactly as implemented below:
      1. ``self.conv1`` followed by the first three ``self.up_residual_units``.
      2. Divide into two streams: ``y`` (starts as ``x``) and
         ``z = self.split_conv(x)`` — presumably FRRN's pooling and residual
         streams; confirm against the class definition.
      3. Encoding: for each ``(n_blocks, channels, scale)`` in
         ``self.encoder_frru_specs``, 2x2/stride-2 max-pool ``y`` and run that
         stage's FRRUs, each returning a new ``(y, z)`` pair.
      4. Decoding: for each spec in ``self.decoder_frru_specs``, bilinearly
         upsample ``y`` to twice its spatial size and run that stage's FRRUs.
      5. Bilinearly upsample ``y`` by 2 once more, concatenate it with ``z``
         along dim 1, then apply ``self.merge_conv``, the first three
         ``self.down_residual_units`` and finally ``self.classif_conv``.

    FRRU submodules are fetched by attribute name, built as
    ``'<encoding|decoding>_frru_{n_blocks}_{channels}_{scale}_{block}'``.

    Args:
        x: Input tensor, assumed (N, C, H, W) — TODO confirm. Presumably H and W
            must be divisible by ``2 ** len(self.encoder_frru_specs)`` so the
            final concatenation of ``y`` and ``z`` lines up — confirm.

    Returns:
        Whatever ``self.classif_conv`` produces for the merged features.
    """
    # Stem: initial conv, then the first three "up" residual units.
    x = self.conv1(x)
    for i in range(3):
        x = self.up_residual_units[i](x)

    # Divide into two streams. In this method z is only ever replaced by
    # FRRU outputs, while y is resampled between stages.
    y = x
    z = self.split_conv(x)

    # Encoding: halve y's spatial size, then run the stage's FRRUs.
    for n_blocks, channels, scale in self.encoder_frru_specs:
        y_pooled = F.max_pool2d(y, stride=2, kernel_size=2, padding=0)
        for block in range(n_blocks):
            key = '_'.join(map(str, ['encoding_frru', n_blocks, channels, scale, block]))
            # NOTE(review): every FRRU in the stage receives the same y_pooled,
            # so only z is chained between blocks and only the last block's y
            # survives. Kept as-is: chaining y would presumably also require the
            # FRRUs to be constructed with matching input channels — confirm
            # against __init__ before changing.
            y, z = getattr(self, key)(y_pooled, z)

    # Decoding: double y's spatial size, then run the stage's FRRUs.
    for n_blocks, channels, scale in self.decoder_frru_specs:
        upsample_size = tuple(2 * s for s in y.shape[-2:])
        # align_corners=False is what the deprecated F.upsample used when the
        # argument was left unset (PyTorch >= 0.4), so numerics are unchanged.
        y_upsampled = F.interpolate(y, size=upsample_size, mode='bilinear',
                                    align_corners=False)
        for block in range(n_blocks):
            key = '_'.join(map(str, ['decoding_frru', n_blocks, channels, scale, block]))
            # NOTE(review): same pattern as the encoder — each block gets the
            # same y_upsampled; only z is chained.
            y, z = getattr(self, key)(y_upsampled, z)

    # Merge streams: bring y up by 2 and concatenate with z on dim 1.
    y_full = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=False)
    x = self.merge_conv(torch.cat([y_full, z], dim=1))

    # Head: the first three "down" residual units, then the classifier conv.
    for i in range(3):
        x = self.down_residual_units[i](x)
    return self.classif_conv(x)
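# (Web-page residue from the page this code was copied from, not code:
#  "评论列表" = "comment list", "文章目录" = "article table of contents".)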