def forward(self, input):
    """Run every branch in ``self.seq_list`` on ``input`` and concatenate the
    results along dimension 1.

    Every branch receives the same ``input`` tensor. Branch outputs may differ
    in the sizes of the dimensions after dim 1. Each output is zero-padded in
    every such dimension up to the largest size seen across branches, splitting
    the padding as evenly as possible. When the difference is odd, the extra
    element goes on the trailing side (e.g. right / bottom). For 4-D outputs
    this is exactly left/right padding on dim 3 and top/bottom on dim 2.

    Args:
        input: value passed unchanged to each module in ``self.seq_list``.
            Branch outputs are presumably laid out as
            ``(batch, channels, *spatial)`` -- NOTE(review): confirm against
            the modules placed in ``seq_list``.

    Returns:
        The (padded) branch outputs concatenated along dim 1.

    Raises:
        ValueError: if ``self.seq_list`` is empty.
    """
    ys = [seq(input) for seq in self.seq_list]
    if not ys:
        raise ValueError("forward() requires at least one module in seq_list")

    # Largest size of every dimension after dim 1, taken over all branches.
    ndim = ys[0].dim()
    target = [max(y.size(d) for y in ys) for d in range(2, ndim)]

    padded = []
    for y in ys:
        # F.pad expects (before, after) pairs starting from the LAST dimension
        # and moving toward the front, so walk the dims in reverse.
        pad = []
        for d in range(ndim - 1, 1, -1):
            diff = target[d - 2] - y.size(d)
            before = diff // 2
            pad.extend((before, diff - before))
        # Skip the copy when this branch is already at the target size.
        padded.append(F.pad(y, pad) if any(pad) else y)
    return torch.cat(padded, 1)
评论列表
文章目录