test_ml_trainer.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:BlueWhale 作者: caffe2 项目源码 文件源码
def get_prediction_dist(
    trainer,
    num_outputs=1,
    num_features=4,
    num_training_samples=100,
    num_test_datapoints=10,
    num_training_iterations=10000,
):
    test_inputs, test_outputs, _ = _train(
        trainer, num_features, num_training_samples, num_test_datapoints,
        num_outputs, num_training_iterations
    )

    predictions = trainer.score(test_inputs)
    dist = np.linalg.norm(test_outputs - predictions)
    return dist
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号