def enwik8_raw_data(data_path=None, num_test_symbols=5000000):
    """Load the enwik8 (Hutter Prize) data and split it into train/valid/test.

    The raw Hutter Prize data is available at:
    http://mattmahoney.net/dc/enwik8.zip

    The file's bytes are remapped to dense integer ids: each byte becomes its
    index in the sorted array of distinct byte values present in the file.
    The last ``num_test_symbols`` ids form the test set. The
    ``num_test_symbols`` ids immediately before those form the validation
    set. Everything earlier is the training set.

    Args:
        data_path: path to the directory containing the extracted ``enwik8``
            file. ``None`` means the current working directory.
        num_test_symbols: number of symbols in each of the validation and
            test sets, both taken from the end of the data. Must be
            non-negative.

    Returns:
        tuple ``(train_data, valid_data, test_data, unique)``. The first three
        are integer id arrays that can be passed to ``hutter_iterator``.
        ``unique`` is the sorted array of distinct ``uint8`` byte values, so
        ``unique[ids]`` maps ids back to the original bytes.

    Raises:
        ValueError: if ``num_test_symbols`` is negative, or if the data holds
            fewer than ``2 * num_test_symbols`` symbols.
    """
    if num_test_symbols < 0:
        raise ValueError(
            "num_test_symbols must be non-negative, got %r" % (num_test_symbols,))
    # os.path.join(None, ...) raises TypeError, so interpret the default of
    # None as "relative to the current working directory".
    if data_path is None:
        data_path = ""
    data_path = os.path.join(data_path, "enwik8")
    raw_data = _read_symbols(data_path)
    if isinstance(raw_data, str):
        # np.fromstring UTF-8-encoded str input; keep that behaviour so the
        # switch to frombuffer (which needs a bytes-like object) is transparent.
        raw_data = raw_data.encode("utf-8")
    # Binary-mode np.fromstring is deprecated; frombuffer is its replacement.
    raw_data = np.frombuffer(raw_data, dtype=np.uint8)
    unique, data = np.unique(raw_data, return_inverse=True)

    num_symbols = len(data)
    if 2 * num_test_symbols > num_symbols:
        raise ValueError(
            "data has %d symbols, fewer than the 2 * %d needed for the "
            "validation and test sets" % (num_symbols, num_test_symbols))
    # Use explicit non-negative bounds: negative-index slicing breaks for
    # num_test_symbols == 0 because data[:-0] is empty.
    valid_start = num_symbols - 2 * num_test_symbols
    test_start = num_symbols - num_test_symbols
    train_data = data[:valid_start]
    valid_data = data[valid_start:test_start]
    test_data = data[test_start:]
    return train_data, valid_data, test_data, unique
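# (Stray web-page navigation labels -- "comment list" and "table of
# contents" -- were pasted here; they are not part of the code.)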