basic_driver.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目：Dense-Net　作者：achyudhk　项目源码　文件源码
def iris_svm(iterations=1000, learning_rate=0.01):
    """Train a small DenseNet on the Iris dataset using an SVM loss.

    Loads the data with ``load_iris`` and builds a ``DenseNet`` with
    ``input_dim=4``, a plain SGD optimizer and ``loss_fn='svm'``. Three ReLU
    layers of widths 4, 6 and 3 are stacked on it. ``dn.train(X, Y)`` is then
    called ``iterations`` times on the full dataset, and the loop index is
    printed before each call.

    Args:
        iterations: Number of times ``dn.train(X, Y)`` is invoked.
            Defaults to 1000.
        learning_rate: Learning rate placed in the SGD ``optim_config``.
            Defaults to 0.01.

    Returns:
        The ``DenseNet`` instance after training, so callers can inspect or
        reuse it.
    """
    print("Initializing net for Iris dataset classification problem. . .")
    iris = load_iris()
    # NOTE(review): presumably X is (n_samples, 4), matching input_dim=4 below — confirm.
    X = iris.data
    Y = iris.target

    dn = DenseNet(
        input_dim=4,
        optim_config={"type": "sgd", "learning_rate": learning_rate},
        loss_fn='svm',
    )
    dn.addlayer("ReLU", 4)
    dn.addlayer("ReLU", 6)
    # Output width 3 presumably corresponds to the three Iris classes — confirm.
    dn.addlayer("ReLU", 3)

    for i in range(iterations):
        print("Iteration: ", i)
        # Trains on the whole (X, Y) each call; whether this is a single
        # update step or a full epoch depends on DenseNet.train — TODO confirm.
        dn.train(X, Y)

    return dn

# def iris_svm_momentum():
#     print("Initializing net for Iris dataset classification problem. . .")
#     iris = load_iris()
#     X = iris.data
#     Y = iris.target
#
#     dn = DenseNet(input_dim=4, optim_config={"type": "momentum", "learning_rate": 0.01, "momentum":0.5}, loss_fn='svm')
#     dn.addlayer("ReLU", 4)
#     dn.addlayer("ReLU", 6)
#     dn.addlayer("ReLU", 3)
#
#     for i in range(1000):
#         print("Iteration: ", i)
#         dn.train(X, Y)

#two_bit_xor_sigmoid()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号