plot_chainer_MLP.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:soft-dtw 作者: mblondel 项目源码 文件源码
def train(network, loss, X_tr, Y_tr, X_te, Y_te, n_epochs=30, gamma=1):
    model= Objective(network, loss=loss, gamma=gamma)

    #optimizer = optimizers.SGD()
    optimizer = optimizers.Adam()
    optimizer.setup(model)

    train = tuple_dataset.TupleDataset(X_tr, Y_tr)
    test = tuple_dataset.TupleDataset(X_te, Y_te)

    train_iter = iterators.SerialIterator(train, batch_size=1, shuffle=True)
    test_iter = iterators.SerialIterator(test, batch_size=1, repeat=False,
                                         shuffle=False)
    updater = training.StandardUpdater(train_iter, optimizer)
    trainer = training.Trainer(updater, (n_epochs, 'epoch'))

    trainer.run()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号