def load_data(dataset):
    """Load a pickled (train, valid, test) dataset into Theano shared variables.

    ``dataset`` is the path of a pickle file that unpacks into exactly three
    splits ``(train_set, valid_set, test_set)``; each split must in turn
    unpack into exactly two items ``(inputs, labels)``.  Paths ending in
    ``.gz`` are read through ``gzip.open``; any other path through the
    builtin ``open``.

    Security note: unpickling can execute arbitrary code -- only pass files
    that come from a trusted source.

    Args:
        dataset: filesystem path (``str``) of the pickle file.

    Returns:
        A list ``[(train_x, train_y), (valid_x, valid_y), (test_x, test_y)]``.
        Each ``*_x`` is a ``theano.shared`` variable holding the inputs as
        ``theano.config.floatX``.  Each ``*_y`` is the labels' shared variable
        (also stored as ``floatX``) wrapped in ``T.cast(..., 'int32')``.

    Raises:
        IOError/OSError: if the file cannot be opened.
        ValueError: if the pickle does not unpack into three splits, or a
            split does not unpack into two items.
    """
    import sys

    def shared_dataset(data_xy, borrow=True):
        """Wrap one ``(inputs, labels)`` split in Theano shared variables.

        ``borrow`` is forwarded unchanged to ``theano.shared``.  Returns
        ``(shared_x, int32 cast of shared_y)``.
        """
        data_x, data_y = data_xy
        shared_x = theano.shared(
            np.asarray(data_x, dtype=theano.config.floatX),
            borrow=borrow)
        # Labels are stored as floatX like the inputs and then exposed as an
        # int32 symbolic cast.  NOTE(review): presumably so they can be kept
        # on a float-only device while still usable as integer indices --
        # confirm against the training code.
        shared_y = theano.shared(
            np.asarray(data_y, dtype=theano.config.floatX),
            borrow=borrow)
        return shared_x, T.cast(shared_y, 'int32')

    # Pickle data is binary: always open with 'rb'.  Text mode ('r') makes
    # pickle.load fail on Python 3 and can corrupt binary protocols on Windows.
    opener = gzip.open if dataset.endswith('.gz') else open
    # `with` closes the handle even if unpickling raises.
    with opener(dataset, 'rb') as f:
        if sys.version_info[0] >= 3:
            # encoding='latin1' is required to unpickle NumPy arrays written
            # by Python 2; it only affects 8-bit strings pickled by Python 2.
            train_set, valid_set, test_set = pkl.load(f, encoding='latin1')
        else:
            train_set, valid_set, test_set = pkl.load(f)

    return [shared_dataset(split)
            for split in (train_set, valid_set, test_set)]
评论列表
文章目录