mlp-digits.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:NumpyDL 作者: oujago 项目源码 文件源码
def main(max_iter):
    # prepare
    npdl.utils.random.set_seed(1234)

    # data
    digits = load_digits()

    X_train = digits.data
    X_train /= np.max(X_train)

    Y_train = digits.target
    n_classes = np.unique(Y_train).size

    # model
    model = npdl.model.Model()
    model.add(npdl.layers.Dense(n_out=500, n_in=64, activation=npdl.activations.ReLU()))
    model.add(npdl.layers.Dense(n_out=n_classes, activation=npdl.activations.Softmax()))
    model.compile(loss=npdl.objectives.SCCE(), optimizer=npdl.optimizers.SGD(lr=0.005))

    # train
    model.fit(X_train, npdl.utils.data.one_hot(Y_train), max_iter=max_iter, validation_split=0.1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号