network.py source code


Project: numpy_cnn   Author: Ryanshuai
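# `ly` is numpy_cnn's own layer module, imported elsewhere in network.py (not shown in this excerpt).
def __init__(self, learning_rate, input_shape, BS):  # input_shape example: [BS, 1, 28, 28]
    # BS (batch size) is not used inside this constructor.
    self.lr = learning_rate  # stored for later parameter updates

    # Block 1: 5x5 convolution, 1 -> 32 channels, stride [1, 1]; ReLU; 2x2 max pooling with stride 2
    self.conv2d_1 = ly.conv2d(input_shape, [5, 5, 1, 32], [1, 1])
    self.relu_1 = ly.relu()
    self.max_pool_1 = ly.max_pooling(self.conv2d_1.output_shape, filter_shape=[2, 2], strides=[2, 2])

    # Block 2: 5x5 convolution, 32 -> 64 channels; ReLU; 2x2 max pooling with stride 2
    self.conv2d_2 = ly.conv2d(self.max_pool_1.output_shape, [5, 5, 32, 64], [1, 1])
    self.relu_2 = ly.relu()
    self.max_pool_2 = ly.max_pooling(self.conv2d_2.output_shape, filter_shape=[2, 2], strides=[2, 2])

    # Flatten the 64 feature maps into one vector per sample
    self.flatter = ly.flatter()

    # Hidden layer: 7*7*64 inputs (assumes the convolutions keep the 28x28 size) -> 1024; ReLU; dropout
    self.full_connect_1 = ly.full_connect(input_len=7 * 7 * 64, output_len=1024)
    self.relu_3 = ly.relu()
    self.dropout_1 = ly.dropout(1024)

    # Output layer: 1024 -> 10 class scores, trained with a softmax cross-entropy loss
    self.full_connect_2 = ly.full_connect(input_len=1024, output_len=10)
    self.loss_func = ly.softmax_cross_entropy_error()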
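The hard-coded `input_len=7*7*64` of `full_connect_1` is only consistent if each `ly.conv2d` keeps the 28x28 spatial size ("same" padding for a 5x5 kernel), so that only the two 2x2, stride-2 pools shrink it: 28 → 14 → 7. The sketch below checks this arithmetic with the standard output-size formula. `conv_out`, `pool_out` and `flattened_length` are hypothetical helpers, not part of numpy_cnn, and the padding that the project's `ly.conv2d` actually applies is an assumption, since the snippet does not show it.

```python
# Hypothetical helpers (not part of numpy_cnn): trace the spatial size of a
# 28x28 input through the two conv + max-pool stages of the network above.


def conv_out(size, kernel, stride, padding):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, window, stride):
    """Output length of an unpadded max-pooling window along one spatial axis."""
    return (size - window) // stride + 1


def flattened_length(size=28, kernel=5, padding=2, channels=(32, 64)):
    """Length of the flattened feature vector fed into the first dense layer."""
    for _ in channels:
        size = conv_out(size, kernel, stride=1, padding=padding)  # conv2d, stride [1, 1]
        size = pool_out(size, window=2, stride=2)                 # 2x2 max pooling, stride 2
    return size * size * channels[-1]


if __name__ == "__main__":
    print(flattened_length(padding=2), 7 * 7 * 64)  # "same" padding: 3136 3136
    print(flattened_length(padding=0))              # no padding: 1024
```

With unpadded ("valid") 5x5 convolutions the same stack would end at 4x4x64 = 1024 features, so `full_connect_1` would then need `input_len=1024` instead.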