def _build(self, flags, files):
    """Load the Adult train/test tables (``adult.data`` / ``adult.test``) into ``self.data``.

    For each selected table, a pickle cache in ``flags.data_path`` is used
    when it exists. Otherwise the raw CSV in ``flags.input_path`` is parsed,
    the ``income_bracket`` column is replaced by an integer ``target`` column,
    and the result is pickled to the cache so later calls skip parsing.

    Args:
        flags: Options object. ``input_path`` (directory containing
            ``adult.data`` and ``adult.test``) and ``data_path`` (directory
            for the ``<table>.pkl`` cache files) are read from it. It is also
            stored as ``self.flags``.
        files: ``"all"`` to load both tables; otherwise a container tested
            with ``in`` for ``"train"`` / ``"test"``. If a plain string is
            passed, ``in`` is a substring test (e.g. ``"train,test"`` works).

    Side effects:
        Sets ``self.flags`` and ``self.data`` (dict: table name -> DataFrame).
        May write ``train.pkl`` / ``test.pkl`` into ``flags.data_path``.
        Prints blank lines and calls ``print_mem_time`` once per table.
    """
    input_dir = flags.input_path   # raw CSV location
    cache_dir = flags.data_path    # pickle cache location

    # (table name, raw file name inside input_dir)
    raw_files = [("train", "adult.data"), ("test", "adult.test")]
    tables = [
        (name, os.path.join(input_dir, fname))
        for name, fname in raw_files
        if files == "all" or name in files
    ]

    columns = [
        "age", "workclass", "fnlwgt", "education", "education_num",
        "marital_status", "occupation", "relationship", "race", "gender",
        "capital_gain", "capital_loss", "hours_per_week", "native_country",
        "income_bracket",
    ]

    print()
    self.flags = flags
    data = {}
    for name, fname in tables:
        pname = os.path.join(cache_dir, name + ".pkl")
        if os.path.exists(pname):
            df = pd.read_pickle(pname)
        else:
            # The test file's first line is skipped (presumably a non-data
            # header line in adult.test -- confirm against the raw file).
            df = pd.read_csv(
                fname,
                header=None,
                skipinitialspace=True,
                names=columns,
                skiprows=1 if name == "test" else None,
            )
            # Substring (not equality) match: any label containing ">50K"
            # maps to 1, everything else to 0.
            df["target"] = (
                df["income_bracket"].str.contains(">50K", regex=False).astype(int)
            )
            df = df.drop(columns="income_bracket")
            df.to_pickle(pname)
        data[name] = df
        print_mem_time("Loaded {} {}".format(os.path.basename(fname), df.shape))
    # Stored without copying: callers share these DataFrame objects.
    self.data = data
    print()
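# 评论列表 ("comment list") -- leftover web-page navigation text, not code
# 文章目录 ("table of contents") -- leftover web-page navigation text, not code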