test_linear_network.py source code

python

Project: shoelace  Author: rjagerman
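```python
import numpy as np
from chainer import links, optimizers, training

# The remaining names come from the shoelace project and this test module;
# their exact import paths are not shown in this excerpt:
#   LtrIterator, Ranker, listnet  -> shoelace package
#   get_dataset, eval             -> helpers defined in the test module
#                                    (note: this eval shadows the built-in)
#   assert_almost_equal           -> float-comparison assertion helper


def test_linear_network():

    # Seed NumPy's global RNG so the run, and the asserted losses, are repeatable
    np.random.seed(1042)

    # Load the data set: a shuffled, repeating iterator for training and a
    # single ordered pass over the same data for evaluation
    dataset = get_dataset(True)
    iterator = LtrIterator(dataset, repeat=True, shuffle=True)
    eval_iterator = LtrIterator(dataset, repeat=False, shuffle=False)

    # Create a linear scoring model with Chainer and wrap it with the
    # ListNet loss function
    predictor = links.Linear(None, 1)
    loss = Ranker(predictor, listnet)

    # Build the optimizer, updater and trainer (10 epochs)
    optimizer = optimizers.Adam(alpha=0.2)
    optimizer.setup(loss)
    updater = training.StandardUpdater(iterator, optimizer)
    trainer = training.Trainer(updater, (10, 'epoch'))

    # Evaluate the loss before training
    before_loss = eval(loss, eval_iterator)

    # Train the network
    trainer.run()

    # Evaluate the loss after training
    after_loss = eval(loss, eval_iterator)

    # Compare against precomputed values
    assert_almost_equal(before_loss, 0.26958397)
    assert_almost_equal(after_loss, 0.2326711)
```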
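The test wraps the linear model in `Ranker(predictor, listnet)`, so training minimises a ListNet loss. As a rough illustration only, the sketch below implements the standard ListNet "top-one" formulation in plain NumPy: the cross-entropy between the softmax of the relevance labels and the softmax of the predicted scores for one query's document list. It is an assumption that shoelace's `listnet` follows this formulation; the function names `softmax` and `listnet_top_one_loss` are hypothetical and not part of shoelace.

```python
import numpy as np


def softmax(v):
    # Subtract the max before exponentiating for numerical stability
    e = np.exp(v - np.max(v))
    return e / e.sum()


def listnet_top_one_loss(scores, relevance):
    """Hypothetical sketch of the ListNet top-one loss for a single query.

    scores:    predicted scores for the query's documents, shape (n,)
    relevance: graded relevance labels for the same documents, shape (n,)
    """
    scores = np.asarray(scores, dtype=np.float64)
    relevance = np.asarray(relevance, dtype=np.float64)
    # Target top-one distribution derived from the labels
    p_true = softmax(relevance)
    # Stable log-softmax of the predicted scores
    shifted = scores - np.max(scores)
    log_p_pred = shifted - np.log(np.exp(shifted).sum())
    # Cross-entropy between the two top-one distributions
    return float(-np.sum(p_true * log_p_pred))


# Scores that order the documents like the labels give a lower loss
# than scores that reverse that order.
agree = listnet_top_one_loss([3.0, 2.0, 1.0], [2, 1, 0])
reverse = listnet_top_one_loss([1.0, 2.0, 3.0], [2, 1, 0])
print(agree < reverse)
```

Because the loss only depends on score differences within a list, a purely linear scorer like `links.Linear(None, 1)` can reduce it by learning feature weights that reproduce the label ordering, which is what the before/after comparison in the test checks.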