SqueezeNet.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:training 作者: bddmodelcar 项目源码 文件源码
def __init__(self, n_steps=10, n_frames=2):
        """Build the network layers and initialise every convolution.

        Args:
            n_steps: Stored as ``self.n_steps``; the final 1x1 convolution
                produces ``n_steps * 2`` output channels.
            n_frames: Stored as ``self.n_frames``; the first convolution
                expects ``3 * 2 * n_frames`` input channels.
        """
        super(SqueezeNet, self).__init__()

        self.n_steps = n_steps
        self.n_frames = n_frames

        # NOTE(review): 3 * 2 * n_frames presumably stands for colour
        # channels x two camera views x frames -- confirm against the
        # data loader.
        self.pre_metadata_features = nn.Sequential(
            nn.Conv2d(3 * 2 * self.n_frames, 16, kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
            Fire(16, 4, 8, 8)
        )

        # NOTE(review): the first Fire below takes 24 input channels. Fire is
        # defined elsewhere; presumably the stage above emits 16 channels and
        # forward() concatenates 8 metadata channels in between -- confirm.
        self.post_metadata_features = nn.Sequential(
            Fire(24, 6, 12, 12),
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
            Fire(24, 8, 16, 16),
            Fire(32, 8, 16, 16),
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
            Fire(32, 12, 24, 24),
            Fire(48, 12, 24, 24),
            Fire(48, 16, 32, 32),
            Fire(64, 16, 32, 32),
        )

        # Keep a local handle on the output convolution so the init loop
        # below can single it out by identity.
        final_conv = nn.Conv2d(64, self.n_steps * 2, kernel_size=1)
        self.final_output = nn.Sequential(
            nn.Dropout(p=0.5),
            final_conv,
            # nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=5, stride=5)
        )

        # Weight initialisation: small-variance normal for the output
        # convolution, Kaiming-uniform for every other convolution, and all
        # convolution biases zeroed. The underscore-suffixed init functions
        # replace the deprecated `init.normal` / `init.kaiming_uniform`
        # aliases, which only forward to these and emit a warning.
        for m in self.modules():
            if not isinstance(m, nn.Conv2d):
                continue
            if m is final_conv:
                init.normal_(m.weight.data, mean=0.0, std=0.01)
            else:
                init.kaiming_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号