VGG_predict.py source code

Project: kaggle-dsg-qualification  Author: Ignotus
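import numpy as np
from keras.optimizers import Adagrad

# Defined elsewhere in the project (not shown in this excerpt): the data module
# imported as `p`, the learning-rate constant LR, build_model() and make_prediction_file().


def make_predictions(shape, model):
    # Load the roof images resized to shape x shape, split into train/validation/test.
    train_data, train_ids, valid_data, valid_labels, test_data, test_ids = p.get_roof_data(shape=(shape, shape))

    print('\tInitializing model')
    opt = Adagrad(lr=LR)
    # `model` is passed to build_model(), which returns the network used for prediction.
    model = build_model(opt, model, shape)

    print('\tCreating predictions')
    # predict_classes() (older Keras API) returns 0-based class indices.
    pred = model.predict_classes(test_data, batch_size=20, verbose=0)
    pred_valid = model.predict_classes(valid_data, batch_size=20, verbose=0)

    # Shift the 0-based class indices to the 1-based labels written to the file.
    pred = np.asarray(pred) + 1
    pred_valid = np.asarray(pred_valid) + 1

    print('\tWriting to file')
    make_prediction_file(test_ids, pred, 'vgg_predictions',
                         valid_labels=valid_labels,
                         valid_predictions=pred_valid)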
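The post-processing step above can be sketched without Keras, assuming the network ends in a softmax over several classes. Older Keras `predict_classes` reduced such outputs with argmax (and thresholded at 0.5 for a single-unit output); the method is gone from recent Keras releases, so argmax over `model.predict(...)` is the usual replacement. This is a minimal sketch under those assumptions: `probs_to_labels`, `write_prediction_file`, `validation_accuracy`, the `Id,label` header and the toy probabilities are all hypothetical, and `write_prediction_file` is only a guess at what the project's `make_prediction_file` might do, not its actual implementation.

```python
import csv
import io

import numpy as np


def probs_to_labels(probs):
    """Pick the most probable class per row and shift it to a 1-based label.

    Assumes `probs` is an (n_samples, n_classes) array of softmax outputs.
    """
    return np.argmax(probs, axis=-1) + 1


def write_prediction_file(ids, labels, out):
    """Hypothetical stand-in for make_prediction_file: write an assumed Id,label CSV."""
    writer = csv.writer(out)
    writer.writerow(["Id", "label"])
    for sample_id, label in zip(ids, labels):
        writer.writerow([sample_id, int(label)])


def validation_accuracy(true_labels, predicted_labels):
    """Fraction of validation samples whose predicted label equals the true label."""
    true_labels = np.asarray(true_labels)
    predicted_labels = np.asarray(predicted_labels)
    return float(np.mean(true_labels == predicted_labels))


# Made-up softmax outputs for 3 test images over 4 classes.
test_probs = np.array([
    [0.70, 0.10, 0.10, 0.10],
    [0.05, 0.05, 0.80, 0.10],
    [0.20, 0.30, 0.10, 0.40],
])
test_labels = probs_to_labels(test_probs)
print(test_labels)  # [1 3 4]

buf = io.StringIO()
write_prediction_file([101, 102, 103], test_labels, buf)
print(buf.getvalue())

# Two of the three made-up validation labels match the predictions.
print(validation_accuracy([1, 3, 2], test_labels))
```

Writing to a file-like object rather than a path keeps the helper easy to test; passing `open('vgg_predictions.csv', 'w', newline='')` instead would write to disk.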