RegNet.py source code


Project: PoseNet    Author: bellatoris
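```python
import torch


def forward(self, inpt):
    # inpt stacks several frames per sample: inpt[:, k] is frame k of every sample.
    # Take the batch size from the input so a smaller final batch still reshapes correctly.
    batch_size = inpt.size(0)

    # Both frames pass through the same shared feature extractor, then are flattened.
    f0 = self.features(inpt[:, 0]).view(batch_size, -1)
    f1 = self.features(inpt[:, 1]).view(batch_size, -1)

    # Disabled variant: five frames stacked as a sequence for a recurrent model.
    # f2 = self.features(inpt[:, 2]).view(batch_size, -1)
    # f3 = self.features(inpt[:, 3]).view(batch_size, -1)
    # f4 = self.features(inpt[:, 4]).view(batch_size, -1)
    # f = torch.stack((f0, f1, f2, f3, f4), dim=0).view(self.seq_length, batch_size, -1)

    # Active path: concatenate the two feature vectors along the feature axis.
    f = torch.cat((f0, f1), dim=1)

    # Disabled variant: run the GRU and use the last layer's hidden state.
    # _, hn = self.rnn(f, self.hidden)
    # hn = hn[self.gru_layer - 1].view(batch_size, -1)
    # hn = self.relu(hn)
    # hn = self.dropout(hn)
    # hn = self.regressor(hn)
    hn = self.regressor(f)

    # Three output heads share the same representation hn.
    trans = self.trans_regressor(hn)

    # Disabled: rescale the translation to unit length.
    # trans_norm = torch.norm(trans, dim=1)
    # trans = torch.div(trans, torch.cat((trans_norm, trans_norm, trans_norm), dim=1))

    scale = self.scale_regressor(hn)
    rotation = self.rotation_regressor(hn)

    return trans, scale, rotation
```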
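To run a forward pass like the one above end to end, here is a minimal, self-contained sketch of a module that could own it. Only the overall flow (shared feature extractor on two frames, concatenation, a shared regressor, then translation, scale and rotation heads) comes from the listing. Everything else is an assumption: the class name `TwoFramePoseRegressor`, the tiny CNN standing in for `self.features`, the layer widths, and the output sizes (3 for translation, which the commented-out three-way normalization suggests; 1 for scale; 3 for rotation) are hypothetical and not the PoseNet project's actual configuration. The hypothetical `normalize_trans` flag re-enables the disabled unit-length step using `torch.nn.functional.normalize`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoFramePoseRegressor(nn.Module):
    """Hypothetical stand-in for the module that owns forward(); sizes are illustrative."""

    def __init__(self, hidden_dim=64, rot_dim=3, normalize_trans=False):
        super().__init__()
        # Assumed tiny shared (Siamese) CNN: 3x32x32 frame -> 8x2x2 = 32 features.
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((2, 2)),
        )
        feat_dim = 8 * 2 * 2
        # The regressor sees both frames' features concatenated.
        self.regressor = nn.Sequential(nn.Linear(2 * feat_dim, hidden_dim), nn.ReLU(inplace=True))
        self.trans_regressor = nn.Linear(hidden_dim, 3)           # assumed 3-D translation
        self.scale_regressor = nn.Linear(hidden_dim, 1)           # assumed scalar scale
        self.rotation_regressor = nn.Linear(hidden_dim, rot_dim)  # assumed rotation size
        self.normalize_trans = normalize_trans

    def forward(self, inpt):
        batch_size = inpt.size(0)
        f0 = self.features(inpt[:, 0]).view(batch_size, -1)
        f1 = self.features(inpt[:, 1]).view(batch_size, -1)
        hn = self.regressor(torch.cat((f0, f1), dim=1))
        trans = self.trans_regressor(hn)
        if self.normalize_trans:
            # Divides each row by its L2 norm (with a tiny epsilon guard), like the
            # commented-out torch.norm / torch.div lines.
            trans = F.normalize(trans, dim=1)
        return trans, self.scale_regressor(hn), self.rotation_regressor(hn)


if __name__ == "__main__":
    torch.manual_seed(0)
    model = TwoFramePoseRegressor(normalize_trans=True)
    x = torch.randn(4, 2, 3, 32, 32)  # (batch, frames, channels, height, width)
    trans, scale, rotation = model(x)
    print(tuple(trans.shape), tuple(scale.shape), tuple(rotation.shape))
    # (4, 3) (4, 1) (4, 3)
```

Concatenating the two feature vectors is the simplest way to fuse the frames; the disabled GRU path in the listing would instead treat the per-frame features as a sequence and regress from the final hidden state.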