def forward(self, inpt):
    """Encode the first two entries of ``inpt`` and regress three outputs.

    ``inpt[:, 0]`` and ``inpt[:, 1]`` are each passed through the shared
    ``self.features`` encoder. Each result is flattened to
    ``(batch, -1)``, and the two are concatenated along dim 1. The
    concatenation is fed to ``self.regressor``, whose output drives three
    separate heads.

    Args:
        inpt: Tensor whose dim 0 is the batch and whose dim 1 holds at
            least two inputs (presumably consecutive frames -- TODO
            confirm against the data pipeline). Entries beyond index 1
            along dim 1 are ignored.

    Returns:
        Tuple ``(trans, scale, rotation)`` produced by
        ``self.trans_regressor``, ``self.scale_regressor`` and
        ``self.rotation_regressor`` respectively, all applied to the
        same ``self.regressor`` output.
    """
    # Take the batch size from the tensor itself rather than the configured
    # ``self.batch_size``. A smaller final batch (or single-sample inference)
    # would otherwise make the flatten fail or silently regroup elements
    # across samples.
    batch_size = inpt.size(0)
    num_frames = 2  # only the first two entries along dim 1 are used

    # ``reshape`` rather than ``view`` so a non-contiguous encoder output
    # does not raise.
    feats = [
        self.features(inpt[:, i]).reshape(batch_size, -1)
        for i in range(num_frames)
    ]
    f = torch.cat(feats, dim=1)

    hn = self.regressor(f)
    trans = self.trans_regressor(hn)
    scale = self.scale_regressor(hn)
    rotation = self.rotation_regressor(hn)
    return trans, scale, rotation
评论列表
文章目录