从Spark(pyspark)的管道中的StringIndexer阶段获取标签

发布于 2021-01-29 15:57:29

我正在使用Sparkpyspark并且已经pipeline设置了一堆StringIndexer对象,用于将字符串列编码为索引列:

indexers = [StringIndexer(inputCol=column, outputCol=column + '_index').setHandleInvalid('skip')
            for column in list(set(data_frame.columns) - ignore_columns)]
pipeline = Pipeline(stages=indexers)
new_data_frame = pipeline.fit(data_frame).transform(data_frame)

问题是,StringIndexer安装好每个对象后,我需要获取它们的标签列表。对于单列和StringIndexer没有管道的单列来说,这是一件容易的事。我可以labels在将索引器安装到上之后访问属性DataFrame

indexer = StringIndexer(inputCol="name", outputCol="name_index")
indexer_fitted = indexer.fit(data_frame)
labels = indexer_fitted.labels
new_data_frame = indexer_fitted.transform(data_frame)

但是,当我使用管道时,这似乎是不可能的,或者至少我不知道该怎么做。

所以我想我的问题归结为:有没有一种方法可以访问在索引过程中为每个单独的列使用的标签?

还是在这个用例中我必须放弃管道,例如循环遍历StringIndexer对象列表并手动执行?(我肯定这是可能的。但是使用管道会更好一些)

关注者
0
被浏览
162
1 个回答
  • 面试哥
    面试哥 2021-01-29
    为面试而生,有面试问题,就找面试哥。

    示例数据和Pipeline

    from pyspark.ml.feature import StringIndexer, StringIndexerModel
    
    df = spark.createDataFrame([("a", "foo"), ("b", "bar")], ("x1", "x2"))
    
    pipeline = Pipeline(stages=[
        StringIndexer(inputCol=c, outputCol='{}_index'.format(c))
        for c in df.columns
    ])
    
    model = pipeline.fit(df)
    

    摘自stages

    # Accessing _java_obj shouldn't be necessary in Spark 2.3+
    {x._java_obj.getOutputCol(): x.labels 
    for x in model.stages if isinstance(x, StringIndexerModel)}
    
    
    
    {'x1_index': ['a', 'b'], 'x2_index': ['foo', 'bar']}
    

    从转换后的元数据DataFrame

    indexed = model.transform(df)
    
    {c.name: c.metadata["ml_attr"]["vals"]
    for c in indexed.schema.fields if c.name.endswith("_index")}
    
    
    
    {'x1_index': ['a', 'b'], 'x2_index': ['foo', 'bar']}
    


知识点
面圈网VIP题库

面圈网VIP题库全新上线,海量真题题库资源。 90大类考试,超10万份考试真题开放下载啦

去下载看看