resnet.py 文件源码

python
阅读 52 收藏 0 点赞 0 评论 0

项目:sesame-paste-noodle 作者: aissehust 项目源码 文件源码
def forward(self, inputtensor):
    """Apply the residual block to ``inputtensor``.

    ``inputtensor`` is a sequence (list-like) of tensors. The full sequence
    is handed to ``self.conv1.forward``, while only its first element,
    ``inputtensor[0]``, is used on the shortcut path.

    Residual branch: conv1 -> bn1 -> relu1 -> conv2 -> bn2. Each layer's
    ``forward`` receives the previous layer's output. The first element of
    bn2's output is added to the shortcut.

    Shortcut:
      * ``self.increaseDim`` truthy: ``inputtensor[0]`` is pooled with a
        2x2 window (``ignore_border=True``). It is then concatenated along
        axis 1 with an all-zero tensor of the same shape, which doubles the
        size of axis 1. Presumably axis 1 is the channel axis, matching a
        conv2 that doubles channels -- TODO confirm. The pooling mode is
        ``pool_2d``'s default (presumably max) -- confirm against the
        installed Theano.
      * otherwise: ``inputtensor[0]`` unchanged (identity shortcut).

    The sum is wrapped in a one-element list and passed to
    ``self.relu2.forward``, whose result is returned as-is.

    NOTE(review): ``T`` is presumably ``theano.tensor``; it is only touched
    on the ``increaseDim`` path.
    """
    # Residual branch: thread the activations through the five layers in order.
    residual = inputtensor
    for layer in (self.conv1, self.bn1, self.relu1, self.conv2, self.bn2):
        residual = layer.forward(residual)

    if self.increaseDim:
        # Downsampling shortcut: shrink spatially, then zero-pad along axis 1
        # so the shape can line up with the residual branch's output.
        pooled = T.signal.pool.pool_2d(inputtensor[0], (2, 2), ignore_border=True)
        shortcut = T.concatenate([pooled, T.zeros_like(pooled)], axis=1)
    else:
        shortcut = inputtensor[0]

    return self.relu2.forward([residual[0] + shortcut])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号