ann_test.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:simec 作者: cod3licious 项目源码 文件源码
def classification(dataset=0, n_train=1000, n_test=50):
    """Fit a small supervised NN on a toy 2-D classification problem and plot it.

    Draws a training and a test set from one of three scikit-learn toy
    generators. Fits a ``SupervisedNNModel`` with two hidden tanh layers
    (100 and 50 units) and a softmax output, then prints the value returned
    by ``model.score`` on the test set. Finally it draws the model's decision
    surface together with the training and test points into a new matplotlib
    figure. The figure is neither shown nor saved here.

    Args:
        dataset: which toy problem to use:
            0 -> ``make_classification`` (two informative features, one
                 cluster per class, plus uniform noise),
            1 -> ``make_moons`` (noise 0.3),
            2 -> ``make_circles`` (noise 0.2, factor 0.5).
            Any other value prints "dataset unknown" and returns ``None``.
        n_train: number of training samples (default 1000).
        n_test: number of held-out test samples (default 50).
    """
    # Draw train and test samples in ONE generator call and split afterwards.
    # Separate calls sharing a random_state but with different n_samples do
    # not give the same distribution: make_classification draws its
    # per-cluster covariance matrices after an n_samples-dependent normal
    # draw. Their leading base draws also coincide, which correlates train and
    # test. A single draw keeps both sets i.i.d. from the same problem.
    n_total = n_train + n_test
    if dataset == 0:
        X_all, Y_all = make_classification(n_samples=n_total, n_features=2, n_redundant=0, n_informative=2,
                                           random_state=1, n_clusters_per_class=1)
        rng = np.random.RandomState(2)
        X_all += 2 * rng.uniform(size=X_all.shape)
    elif dataset == 1:
        X_all, Y_all = make_moons(n_samples=n_total, noise=0.3, random_state=0)
    elif dataset == 2:
        X_all, Y_all = make_circles(n_samples=n_total, noise=0.2, factor=0.5, random_state=1)
    else:
        print("dataset unknown")
        return
    # The generators shuffle by default, so the head/tail split is a random split.
    X, Y = X_all[:n_train], Y_all[:n_train]
    X_test, Y_test = X_all[n_train:], Y_all[n_train:]

    # build, train, and test the model
    model = SupervisedNNModel(X.shape[1], 2, hunits=[100, 50], activations=[T.tanh, T.tanh, T.nnet.softmax],
                              cost_fun='negative_log_likelihood', error_fun='zero_one_loss',
                              learning_rate=0.01, L1_reg=0., L2_reg=0.)
    model.fit(X, Y)
    # NOTE(review): with error_fun='zero_one_loss' the score is presumably the
    # misclassification rate -- confirm against SupervisedNNModel.score.
    print("Test Error: %f" % model.score(X_test, Y_test))

    # plot dataset + predictions
    plt.figure()
    # Bounds cover train AND test points so no test point is clipped by xlim/ylim.
    x_min, x_max = X_all[:, 0].min() - .5, X_all[:, 0].max() + .5
    y_min, y_max = X_all[:, 1].min() - .5, X_all[:, 1].max() + .5
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
                         np.arange(y_min, y_max, 0.02))
    cm = plt.cm.RdBu
    cm_bright = ListedColormap(['#FF0000', '#0000FF'])

    # Assumes predict() returns per-class outputs of shape (n_points, 2) and
    # column 1 is the class-1 output -- TODO confirm against SupervisedNNModel.
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()])[:, 1]

    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    plt.contourf(xx, yy, Z, cmap=cm, alpha=.8)

    # Plot also the training points (semi-transparent)
    plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=cm_bright, alpha=0.6)
    # and testing points (opaque)
    plt.scatter(X_test[:, 0], X_test[:, 1], c=Y_test, cmap=cm_bright)

    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.xticks(())
    plt.yticks(())
    plt.title('Classification Problem (%i)' % dataset)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号