def forward(self, embedding):
    """Decode a batch of embeddings into a sigmoid-activated feature map.

    Pipeline:
    - The embedding goes through ``self.ae_fc1``, its batch norm and a ReLU.
    - The result is reshaped to ``(B, 128, 3, 5)``.
    - It is grown by four nearest-neighbour 2x upsampling steps, interleaved
      with the ``ae_c1``..``ae_c4`` convolutions. Two single-row zero pads
      are applied along the way.
    - The output of ``ae_c4`` is passed through a sigmoid.

    Args:
        embedding: Input batch. The original author annotates it as
            ``B x 256``. NOTE(review): confirm against ``ae_fc1``'s
            ``in_features``.

    Returns:
        The sigmoid of ``self.ae_c4``'s output, so every value is in (0, 1).

    Raises:
        RuntimeError: if ``ae_fc1`` does not produce exactly ``128 * 3 * 5``
            features per sample (raised by ``Tensor.view``).

    The spatial sizes in the trailing comments come from the original author.
    They assume the conv layers (defined elsewhere) preserve size, except
    ``ae_c2``, which trims two rows. Under that assumption the output is
    45x80 (H x W).
    """

    def act(x):
        # In-place ReLU is fine here: act() is only applied to the fresh
        # outputs of the ae_*_bn layers, which nothing else holds on to.
        return F.relu(x, inplace=True)

    def up(x):
        # Nearest-neighbour 2x upsampling. This is numerically identical to
        # nn.UpsamplingNearest2d(scale_factor=2), but avoids constructing a
        # deprecated module object on every forward pass.
        return F.interpolate(x, scale_factor=2, mode="nearest")

    x_ae = act(self.ae_fc1_bn(self.ae_fc1(embedding)))
    # Keep the batch size explicit. A wrongly sized ae_fc1 output then raises,
    # instead of being silently reinterpreted as extra samples (as view(-1, ...)
    # would do).
    x_ae = x_ae.view(x_ae.size(0), 128, 3, 5)        # 128x3x5
    x_ae = up(x_ae)                                   # 6x10
    x_ae = act(self.ae_c1_bn(self.ae_c1(x_ae)))       # 6x10
    x_ae = up(x_ae)                                   # 12x20
    x_ae = act(self.ae_c2_bn(self.ae_c2(x_ae)))       # 12x20 -> 10x20
    # Pad tuple is (left, right, top, bottom) over the last two dims:
    # add one zero row on top.
    x_ae = F.pad(x_ae, (0, 0, 1, 0))                  # 11x20
    x_ae = up(x_ae)                                   # 22x40
    x_ae = act(self.ae_c3_bn(self.ae_c3(x_ae)))       # 22x40
    x_ae = up(x_ae)                                   # 44x80
    x_ae = F.pad(x_ae, (0, 0, 1, 0))                  # 44 -> 45 rows
    # Tensor.sigmoid replaces the deprecated F.sigmoid.
    return self.ae_c4(x_ae).sigmoid()
评论列表
文章目录