def load_data(self, train_dfn="adult.data", test_dfn="adult.test"):
    """
    Load the "adult" train/test CSV files into pandas DataFrames.

    Uses the files offered in the Tensorflow wide_n_deep_tutorial. Each file
    is downloaded from the UCI archive only when it does not already exist
    at the given local path.

    Args:
        train_dfn: Local path of the training CSV; downloaded if missing.
        test_dfn: Local path of the test CSV; downloaded if missing.

    Side effects:
        Sets ``self.train_data`` and ``self.test_data`` and adds a 0/1 column
        named ``self.label_column`` to both (1 when ``income_bracket``
        contains ">50K"). May write the two files to disk and prints one
        message per download.
    """
    # urlretrieve lives in urllib.request on Python 3 and directly in urllib
    # on Python 2; a plain `urllib.urlretrieve` fails on Python 3.
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve

    base_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/"
    downloads = (
        (train_dfn, base_url + "adult.data", "Training data is downloaded to %s"),
        (test_dfn, base_url + "adult.test", "Test data is downloaded to %s"),
    )
    for local_path, url, message in downloads:
        if not os.path.exists(local_path):
            urlretrieve(url, local_path)
            print(message % local_path)

    self.train_data = pd.read_csv(train_dfn, names=COLUMNS, skipinitialspace=True)
    # The first line of the test file is skipped; presumably it is a
    # non-data header line -- confirm against the downloaded file.
    self.test_data = pd.read_csv(
        test_dfn, names=COLUMNS, skipinitialspace=True, skiprows=1
    )

    # Substring match (not equality) so values such as ">50K." also count
    # as positive.
    for frame in (self.train_data, self.test_data):
        frame[self.label_column] = (
            frame["income_bracket"].apply(lambda x: ">50K" in x)
        ).astype(int)
# (web-page residue, not code) Comment list
# (web-page residue, not code) Article table of contents