cnn_model.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:cgp-cnn 作者: sg-nm 项目源码 文件源码
def __call__(self, x, train):
        param_num = 0
        for name, f in self.forward:
            if 'conv1' in name:
                x = getattr(self, name)(x)
                param_num += (f.W.shape[0]*f.W.shape[2]*f.W.shape[3]*f.W.shape[1]+f.W.shape[0])
            elif 'bn1' in name:
                x = getattr(self, name)(x, not train)
                param_num += x.data.shape[1]*2
        return (F.relu(x), param_num)


# [(CONV -> Batch -> ReLU -> CONV -> Batch) + (x)]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号