segmentation.py source code

Project: crankshaft, Author: CartoDB
```python
import numpy as np
import plpy  # provided by PostgreSQL's PL/Python runtime


def predict_segment(model, features, target_query):
    """
    Use the provided model to predict the values for the new feature set.
        Input:
            @param model: The pretrained model
            @param features: A list of features to use in the model prediction (list of column names)
            @param target_query: The query to run to obtain the data to predict on and the cartodb_ids associated with it.
        Output:
            A tuple (cartodb_ids, predictions), both ordered by cartodb_id.
    """

    batch_size = 1000
    joined_features = ','.join(['"{0}"::numeric'.format(a) for a in features])

    try:
        # Order by cartodb_id so the predictions line up with the ids,
        # which are aggregated in cartodb_id order below.
        cursor = plpy.cursor('SELECT Array[{joined_features}] As features '
                             'FROM ({target_query}) As a '
                             'ORDER BY cartodb_id'.format(
                                 joined_features=joined_features,
                                 target_query=target_query))
    except Exception as e:
        plpy.error('Failed to fetch data for segmentation prediction: %s' % e)

    results = []

    while True:
        rows = cursor.fetch(batch_size)
        if not rows:
            break
        # SQL NULLs become NaN when converted with dtype=float.
        batch = np.vstack([np.array(row['features'], dtype=float) for row in rows])

        # TODO: impute with the global column means. Per-batch means make a
        # row's prediction depend on which batch it happens to land in.
        batch = replace_nan_with_mean(batch)  # helper defined elsewhere in segmentation.py
        prediction = model.predict(batch)
        results.append(prediction)

    try:
        cartodb_ids = plpy.execute('''SELECT array_agg(cartodb_id ORDER BY cartodb_id) As cartodb_ids FROM ({0}) As a'''.format(target_query))[0]['cartodb_ids']
    except Exception as e:
        plpy.error('Failed to fetch cartodb_ids for segmentation prediction: %s' % e)

    return cartodb_ids, np.concatenate(results)
```
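The comment in the prediction loop notes that missing values should be filled with the global mean, not each batch's mean. Below is a minimal, stdlib-only sketch of that two-pass approach, assuming the data can be read twice (in PL/Python that would mean opening the cursor a second time, or computing the means in SQL with `avg()`, which skips NULLs). The names `batches_from`, `global_column_means`, `predict_batched` and the toy model are hypothetical and not part of crankshaft; plain lists stand in for NumPy arrays and the plpy cursor.

```python
import math
from typing import Callable, Iterator, List

Row = List[float]
BatchSource = Callable[[], Iterator[List[Row]]]


def batches_from(rows: List[Row], batch_size: int) -> BatchSource:
    """Hypothetical stand-in for re-opening a cursor: each call yields fresh batches."""
    def source() -> Iterator[List[Row]]:
        for start in range(0, len(rows), batch_size):
            yield rows[start:start + batch_size]
    return source


def global_column_means(source: BatchSource) -> List[float]:
    """Pass 1: mean of each column over every batch, skipping NaNs."""
    sums: List[float] = []
    counts: List[int] = []
    for batch in source():
        for row in batch:
            if not sums:
                sums, counts = [0.0] * len(row), [0] * len(row)
            for j, value in enumerate(row):
                if not math.isnan(value):
                    sums[j] += value
                    counts[j] += 1
    # A column with no observed values falls back to 0.0 (an arbitrary choice).
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]


def predict_batched(source: BatchSource,
                    predict: Callable[[List[Row]], List[int]]) -> List[int]:
    """Pass 2: fill NaNs with the global means, then predict batch by batch."""
    means = global_column_means(source)
    results: List[int] = []
    for batch in source():
        filled = [[means[j] if math.isnan(v) else v for j, v in enumerate(row)]
                  for row in batch]
        results.extend(predict(filled))
    return results


if __name__ == "__main__":
    nan = float("nan")
    rows = [[1.0, 10.0], [nan, 20.0], [3.0, nan], [5.0, 40.0]]
    # Hypothetical toy model: label 1 when the first feature exceeds 2.5.
    toy_model = lambda batch: [int(row[0] > 2.5) for row in batch]
    print(predict_batched(batches_from(rows, batch_size=2), toy_model))  # [0, 1, 1, 1]
```

Here the second row's missing first feature is filled with the global mean 3.0, so it is labelled 1 whatever the batch size. With per-batch means and `batch_size=2`, it would instead be filled with 1.0 (the only observed first-column value in its batch) and labelled 0.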