def __init__(self, inShape, inField, outValidation, outTrain,
             number=50, percent=True):
    """Split the features of a vector file into two stratified subsets.

    Every feature of the first layer of ``inShape`` is read. Its ``inField``
    value is cast to ``int`` and used as the stratification label. The
    features are then divided with scikit-learn's ``train_test_split``, and
    both parts are written with ``self.saveToShape``.

    inShape : str path file (e.g. '/doc/ref.shp')
    inField : string column name (e.g. 'class')
    outValidation : str path of shp output file (e.g. '/tmp/valid.shp')
    outTrain : str path of shp output file (e.g. '/tmp/train.shp')
    number : size of the part passed as ``test_size``. With ``percent=True``
        it is a percentage (50 -> half of the features). Otherwise it is
        cast to ``int`` and used as an absolute number of features.
    percent : True if ``number`` is a percentage, False if it is a count.

    ``train_test_split`` returns ``(train_part, test_part)``, and the result
    is unpacked as ``(validation, train)``. So the ``number``-sized part is
    written to ``outTrain`` and the remainder to ``outValidation``.
    NOTE(review): confirm this is intended; if ``number`` is meant to be the
    validation share, the two outputs are swapped.
    """
    if percent:
        number = number / 100.0
    else:
        number = int(number)

    datasource = ogr.Open(inShape)
    layer = datasource.GetLayer()

    # Gather labels and features in the same pass so both sequences always
    # have the same length, rather than pre-sizing from GetFeatureCount().
    features = []
    labels = []
    for feature in layer:
        labels.append(int(feature.GetField(inField)))
        features.append(feature)
    srs = layer.GetSpatialRef()
    layer.ResetReading()

    # Give only test_size: scikit-learn then sizes the other part as the
    # exact complement. If train_size were also given as a float, it would
    # be floored independently of the ceil'd test size, and rounding could
    # leave a feature in neither subset.
    validation, train = train_test_split(features, test_size=number,
                                         stratify=labels)
    self.saveToShape(validation, srs, outValidation)
    self.saveToShape(train, srs, outTrain)
评论列表
文章目录