def assert_spark_df(df):
    """Raise an AssertionError if ``df`` is not a Spark DataFrame."""
    import pyspark.sql
    assert isinstance(df, pyspark.sql.dataframe.DataFrame), "Not a Spark DF"
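A minimal usage sketch (added for illustration, not from the source): build a toy DataFrame with a local session and pass it through assert_spark_df; the application name and sample rows are placeholders.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("assert-df-demo").getOrCreate()  # hypothetical app name
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])     # toy data
assert_spark_df(df)  # passes: df is a pyspark.sql.dataframe.DataFrame
spark.stop()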
def spark_context(self, application_name):
    """Create a Spark context using the parameters configured in this class.

    The caller is responsible for calling ``.stop`` on the resulting Spark context.

    Parameters
    ----------
    application_name : string
        Name under which the application is registered with the cluster.

    Returns
    -------
    sc : SparkContext
    """
    # initialize the spark configuration
    self._init_spark()
    import pyspark
    import pyspark.sql

    # build the SparkConf from the key/value pairs collected by the helper
    spark_conf = pyspark.SparkConf()
    for k, v in self._spark_conf_helper._conf_dict.items():
        spark_conf.set(k, v)

    log.info("Starting SparkContext")
    return pyspark.SparkContext(appName=application_name, conf=spark_conf)
def spark_session(self, application_name):
    """Create a ``SparkSession`` backed by a context built from this configuration."""
    sc = self.spark_context(application_name)
    from pyspark.sql import SparkSession
    return SparkSession(sc)
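A hedged usage sketch: assuming ``default_configuration`` is an instance of the SparkConfiguration class these methods belong to (it is referenced by with_sql_context below but never constructed on this page), a session can be obtained and checked like this.

spark = default_configuration.spark_session("MyApplication")  # assumed module-level configuration object
spark.range(10).show()  # quick sanity check that the session is usable
spark.stop()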
from contextlib import contextmanager

@contextmanager
def with_sql_context(application_name, conf=None):
    """Context manager for a Spark context and its SQLContext.

    Parameters
    ----------
    application_name : string
    conf : SparkConfiguration, optional
        Falls back to ``default_configuration`` when omitted.

    Returns
    -------
    sc : SparkContext
    sql_context : SQLContext

    Examples
    --------
    Used within a context manager
    >>> with with_sql_context("MyApplication") as (sc, sql_context):
    ...     import pyspark
    ...     # Do stuff
    ...     pass
    """
    if conf is None:
        conf = default_configuration
    assert isinstance(conf, SparkConfiguration)

    sc = conf.spark_context(application_name)
    import pyspark.sql
    try:
        yield sc, pyspark.sql.SQLContext(sc)
    finally:
        sc.stop()
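A short usage sketch of the context manager (illustrative only; the table name and rows are made up): the SQLContext yielded above can run plain SQL against a registered temp view, and the SparkContext is stopped automatically on exit.

with with_sql_context("MyApplication") as (sc, sql_context):
    df = sql_context.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])  # toy data
    df.createOrReplaceTempView("demo")
    sql_context.sql("SELECT COUNT(*) AS n FROM demo").show()
# sc is stopped by the finally block once the with-block exits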
def calcAll(genotypeRDD, gwasRDD, thresholdlist, logsnp, samplenum, calltableRDD=False):
    """Calculate the PRS at each p-value threshold in ``thresholdlist``.

    Relies on module-level names (``sc``, ``logger``, ``gwas_p``, ``gwas_id``,
    ``gwas_or``, ``log_or``, ``filterGWASByP_DF``, ``calcPRSFromGeno``) defined
    elsewhere in this module.
    """
    logger.info("Started calculating PRS at each threshold")
    prsMap = {}
    thresholdNoMaxSorted = sorted(thresholdlist, reverse=True)
    thresholdmax = max(thresholdlist)
    idlog = {}
    start = time.time()
    for threshold in thresholdNoMaxSorted:
        tic = time.time()
        # broadcast the GWAS entries that pass the current p-value threshold
        gwasFilteredBC = sc.broadcast(filterGWASByP_DF(
            GWASdf=gwasRDD, pcolumn=gwas_p, idcolumn=gwas_id,
            oddscolumn=gwas_or, pHigh=threshold, logOdds=log_or))
        # gwasFiltered = spark.sql("SELECT snpid, gwas_or_float FROM gwastable WHERE gwas_p_float < {:f}".format(threshold))
        logger.info("Filtered GWAS at threshold of {}. Time spent : {:.1f} seconds".format(str(threshold), time.time() - tic))
        checkpoint = time.time()
        # keep only the genotype and call-table rows whose SNP id survived the filter
        filteredgenotype = genotypeRDD.filter(lambda line: line[0] in gwasFilteredBC.value)
        assert calltableRDD, "Error, calltable must be provided"
        filteredcalltable = calltableRDD.filter(lambda line: line[0] in gwasFilteredBC.value)
        if not filteredgenotype.isEmpty():
            # assert filteredcalltable.count() == filteredgenotype.count(), "Error, call table has a different size from genotype"
            if logsnp:
                idlog[threshold] = filteredgenotype.map(lambda line: line[0]).collect()
            prsMap[threshold] = calcPRSFromGeno(filteredgenotype, gwasFilteredBC.value, samplenum=samplenum, calltable=filteredcalltable)
            logger.info("Finished calculating PRS at threshold of {}. Time spent : {:.1f} seconds".format(str(threshold), time.time() - checkpoint))
        else:
            logger.warning("No SNPs left at threshold {}".format(threshold))
    return prsMap, idlog
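The core of calcAll is a broadcast-membership filter: the SNP ids surviving the p-value threshold are broadcast, and the genotype RDD keeps only rows whose key is in that broadcast value. A self-contained sketch of that pattern with toy data (assumes a live SparkContext bound to ``sc``, exactly as calcAll itself does):

kept_ids = sc.broadcast({"rs1", "rs3"})  # stand-in for gwasFilteredBC
genotype = sc.parallelize([("rs1", [0, 1, 2]), ("rs2", [1, 1, 0]), ("rs3", [2, 0, 1])])
filtered = genotype.filter(lambda line: line[0] in kept_ids.value)
print(filtered.keys().collect())  # ['rs1', 'rs3']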
def register(sc):
    """Make the AWS Glue Java classes visible on the JVM gateway attached to ``sc``."""
    from py4j.java_gateway import java_import
    java_import(sc._jvm, "com.amazonaws.services.glue.*")
    java_import(sc._jvm, "com.amazonaws.services.glue.schema.*")
    java_import(sc._jvm, "org.apache.spark.sql.glue.GlueContext")
    java_import(sc._jvm, "com.amazonaws.services.glue.util.JsonOptions")
    java_import(sc._jvm, "org.apache.spark.sql.glue.util.SparkUtility")
    java_import(sc._jvm, "org.apache.spark.sql.glue.gluefunctions")
    java_import(sc._jvm, "com.amazonaws.services.glue.util.Job")
    java_import(sc._jvm, "com.amazonaws.services.glue.util.AWSConnectionUtils")
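A hedged usage sketch for register (assumes the AWS Glue jars are on the driver classpath; java_import only records the import, so the call succeeds either way, but the classes resolve only when the jars are present):

import pyspark

sc = pyspark.SparkContext(appName="glue-register-demo")  # hypothetical app name
register(sc)
# Once registered, the classes are addressable through the JVM view,
# e.g. sc._jvm.GlueContext or sc._jvm.Job (only if the Glue jars are loaded).
sc.stop()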