c2_p2_f2_dropout_cross_entropy_net.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:numpy_cnn 作者: Ryanshuai 项目源码 文件源码
def __init__(self, learning_rate, input_shape):  # input_shape example: [BS,1,28,28]
    """Build a two-conv / two-fully-connected CNN with dropout and a softmax cross-entropy loss.

    Architecture (all convs/pools use 'SAME' padding):
        conv 5x5 (6 filters, stride 1) -> ReLU -> max-pool 2x2 (stride 2)
        conv 5x5 (10 filters, stride 1) -> ReLU -> max-pool 2x2 (stride 2)
        flatten -> fully-connected (-> 84) -> ReLU -> dropout
        fully-connected (84 -> 10) -> softmax cross-entropy loss

    Args:
        learning_rate: step size; stored unchanged as ``self.lr``.
        input_shape: shape of an input batch laid out as
            (batch, channels, height, width), e.g. ``[BS, 1, 28, 28]``.

    The shape comments below are for the example input ``[BS, 1, 28, 28]``.
    The conv1 input-channel count and the fc1 input size are derived from
    ``input_shape`` rather than hard-coded, so other channel counts and
    image sizes also yield consistently sized layers. For the example input
    they evaluate to the previous constants (1 and 490).
    """
    self.lr = learning_rate

    # conv1: (BS,1,28,28) -> (BS,6,28,28) -> (BS,6,14,14)
    # Kernel spec is [kh, kw, in_channels, out_channels] (compare conv2's
    # [5, 5, 6, 10] mapping 6 -> 10 channels); take in_channels from the
    # input instead of assuming a single channel.
    in_channels = int(input_shape[1])
    self.conv2d_1 = ly.conv2d(input_shape, [5, 5, in_channels, 6], [1, 1], 'SAME')
    self.relu_1 = ly.relu()
    self.pool_1 = ly.max_pooling(self.conv2d_1.output_shape, [2, 2], [2, 2], 'SAME')

    # conv2: (BS,6,14,14) -> (BS,10,14,14) -> (BS,10,7,7)
    self.conv2d_2 = ly.conv2d(self.pool_1.output_shape, [5, 5, 6, 10], [1, 1], 'SAME')
    self.relu_2 = ly.relu()
    self.pool_2 = ly.max_pooling(self.conv2d_2.output_shape, [2, 2], [2, 2], 'SAME')

    # flat: (BS,10,7,7) -> (BS,490)
    self.flatter = ly.flatter()

    # Flattened feature count = product of pool_2's non-batch dims
    # (10*7*7 = 490 for a 28x28 input). Assumes output_shape is laid out
    # like input_shape with the batch size first — it is passed to the next
    # layer in the same position as input_shape above; TODO confirm in ly.
    flat_size = 1
    for dim in self.pool_2.output_shape[1:]:
        flat_size *= int(dim)

    # fc1: (BS,490) -> (BS,84)
    self.full_connect_1 = ly.full_connect(flat_size, 84)
    self.relu_3 = ly.relu()
    # NOTE(review): the keyword is spelled `lenth`, presumably matching the
    # ly.dropout signature — keep as-is unless ly is changed too.
    self.dropout = ly.dropout(lenth=84)

    # fc2: (BS,84) -> (BS,10)
    self.full_connect_2 = ly.full_connect(84, 10)

    self.loss_func = ly.softmax_cross_entropy_error()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号