alexnet.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:caffe-model 作者: GeekLiB 项目源码 文件源码
def alexnet_proto(self, batch_size, phase='TRAIN'):
        """Build an AlexNet Caffe NetSpec and return it as a protobuf.

        Args:
            batch_size: Batch size passed to the LMDB ``Data`` layer.
            phase: ``'TRAIN'`` reads ``self.train_data`` with ``mirror=True``.
                Any other value reads ``self.test_data`` with ``mirror=False``
                and also attaches top-1/top-5 accuracy outputs.

        Returns:
            The result of ``n.to_proto()`` for the assembled network.
        """
        is_train = phase == 'TRAIN'
        n = caffe.NetSpec()

        # Input: the train LMDB with mirroring augmentation in TRAIN, else the
        # test LMDB without it. Images are cropped to 227x227 and mean-subtracted
        # (the per-channel means are presumably BGR order — confirm dataset).
        source_data = self.train_data if is_train else self.test_data
        n.data, n.label = L.Data(source=source_data, backend=P.Data.LMDB, batch_size=batch_size, ntop=2,
                                 transform_param=dict(crop_size=227, mean_value=[104, 117, 123], mirror=is_train))

        # Shape comments below are CxHxW per sample, derived from the 227 crop.
        # NOTE(review): the next layer consumes n.convK / n.fcK rather than the
        # ReLU/Dropout tops, which assumes those helpers build in-place layers.
        n.conv1, n.relu1 = conv_relu(n.data, num_output=96, kernel_size=11, stride=4)  # 96x55x55
        n.norm1 = L.LRN(n.conv1, local_size=5, alpha=0.0001, beta=0.75)
        n.pool1 = L.Pooling(n.norm1, kernel_size=3, stride=2, pool=P.Pooling.MAX)  # 96x27x27

        n.conv2, n.relu2 = conv_relu(n.pool1, num_output=256, kernel_size=5, pad=2, group=2)  # 256x27x27
        n.norm2 = L.LRN(n.conv2, local_size=5, alpha=0.0001, beta=0.75)
        n.pool2 = L.Pooling(n.norm2, kernel_size=3, stride=2, pool=P.Pooling.MAX)  # 256x13x13

        n.conv3, n.relu3 = conv_relu(n.pool2, num_output=384, kernel_size=3, pad=1)  # 384x13x13
        n.conv4, n.relu4 = conv_relu(n.conv3, num_output=384, kernel_size=3, pad=1, group=2)  # 384x13x13

        n.conv5, n.relu5 = conv_relu(n.conv4, num_output=256, kernel_size=3, pad=1, group=2)  # 256x13x13
        n.pool5 = L.Pooling(n.conv5, kernel_size=3, stride=2, pool=P.Pooling.MAX)  # 256x6x6

        n.fc6, n.relu6, n.drop6 = fc_relu_drop(n.pool5, num_output=4096)  # 4096x1x1
        n.fc7, n.relu7, n.drop7 = fc_relu_drop(n.fc6, num_output=4096)  # 4096x1x1
        # Classifier: weights use lr_mult=1/decay_mult=1, bias uses
        # lr_mult=2/decay_mult=0; Gaussian(std=0.01) weights, zero bias.
        n.fc8 = L.InnerProduct(n.fc7, num_output=self.classifier_num,
                               param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)],
                               weight_filler=dict(type='gaussian', std=0.01),
                               bias_filler=dict(type='constant', value=0))
        n.loss = L.SoftmaxWithLoss(n.fc8, n.label)
        if not is_train:
            # Accuracy outputs are only attached outside the TRAIN phase.
            n.accuracy_top1, n.accuracy_top5 = accuracy_top1_top5(n.fc8, n.label)

        return n.to_proto()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号