import chainer
import numpy as np
from chainer.datasets import TupleDataset


def get_mnist():
    # Load MNIST as (image, label) pairs with images shaped (1, 28, 28).
    train, test = chainer.datasets.get_mnist(ndim=3)
    # Stack the pairs into object arrays so images and labels can be split apart.
    train_data = np.array([t for t in train])
    test_data = np.array([t for t in test])
    train_data = np.expand_dims(train_data, 1)
    test_data = np.expand_dims(test_data, 1)
    train_xs = train_data[:, :, 0].T
    train_ys = train_data[:, :, 1].T
    test_xs = test_data[:, :, 0].T
    test_ys = test_data[:, :, 1].T
    # Rebuild TupleDatasets from the separated image and label lists.
    train = TupleDataset(*(train_xs.tolist() + train_ys.tolist()))
    test = TupleDataset(*(test_xs.tolist() + test_ys.tolist()))
    return train, test
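A minimal usage sketch, not part of the original snippet: it only illustrates that the rebuilt TupleDataset can be consumed like any other Chainer dataset, for example with a SerialIterator; batch_size=32 is an arbitrary assumption.

# Hedged usage sketch: iterate the rebuilt dataset and collate one minibatch.
import chainer

train, test = get_mnist()
it = chainer.iterators.SerialIterator(train, batch_size=32, repeat=False, shuffle=False)
batch = it.next()                                   # list of (image, label) pairs
imgs, labels = chainer.dataset.concat_examples(batch)
print(imgs.shape, labels.shape)                     # roughly (32, 1, 28, 28) and (32,)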
def score_core(self, X, y=None, sample_weight=None, batchsize=16):
    # Type check
    X, y = self._check_X_y(X, y)
    # `y` may be omitted, e.g. during GridSearch, which only assumes the
    # score(X, y) interface.
    if y is None:
        test = X
        if isinstance(test, numpy.ndarray):  # TODO: review
            print('score_core numpy.ndarray received...')
            test = chainer.datasets.TupleDataset(test)
    else:
        test = chainer.datasets.TupleDataset(X, y)
    # For Classifier:
    #     `accuracy` is calculated as the score, using `forward_batch`.
    # For Regressor:
    #     `loss` is calculated as the score, using `forward_batch`.
    self.forward_batch(test, batchsize=batchsize, retain_inputs=False,
                       calc_score=True)
    return self.total_score
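A small self-contained sketch of the dataset wrapping performed above: TupleDataset(X, y) pairs the i-th row of X with the i-th label, while TupleDataset(X) yields 1-tuples, mirroring the `y is None` branch. The shapes and dtypes are illustrative assumptions only.

# Self-contained sketch of the wrapping done in score_core.
import numpy
import chainer

X = numpy.random.rand(8, 4).astype(numpy.float32)
y = numpy.arange(8).astype(numpy.int32)

with_labels = chainer.datasets.TupleDataset(X, y)   # (x_i, y_i) pairs
without_labels = chainer.datasets.TupleDataset(X)   # (x_i,) 1-tuples, as in the y is None branch
print(with_labels[0][1], len(without_labels[0]))    # prints: 0 1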
import os

import chainer
from chainer import serializers
from chainer.datasets import TransformDataset
# `resize` is assumed to come from chainercv here; `Loss`, `pretrain_source_cnn`,
# `test_pretrained_on_target` and `train_target_cnn` are project-local helpers
# that are not defined in this snippet.
from chainercv.transforms import resize


def main(args):
    # get datasets: SVHN is the source domain, MNIST the target domain
    source_train, source_test = chainer.datasets.get_svhn()
    target_train, target_test = chainer.datasets.get_mnist(ndim=3, rgb_format=True)
    source = source_train, source_test

    # resize MNIST images to 32x32 so they match the SVHN input size
    def transform(in_data):
        img, label = in_data
        img = resize(img, (32, 32))
        return img, label

    target_train = TransformDataset(target_train, transform)
    target_test = TransformDataset(target_test, transform)
    target = target_train, target_test

    # load a pretrained source model, or perform the source pretraining
    pretrained = os.path.join(args.output, args.pretrained_source)
    if not os.path.isfile(pretrained):
        source_cnn = pretrain_source_cnn(source, args)
    else:
        source_cnn = Loss(num_classes=10)
        serializers.load_npz(pretrained, source_cnn)

    # how well does the source-only model perform on the target domain?
    test_pretrained_on_target(source_cnn, target, args)

    # initialize the target cnn (do not use source_cnn.copy)
    target_cnn = Loss(num_classes=10)
    # copy parameters from the source cnn to the target cnn
    target_cnn.copyparams(source_cnn)

    train_target_cnn(source, target, source_cnn, target_cnn, args)
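A standalone sketch of the MNIST 28x28 to 32x32 preprocessing step used above, assuming chainercv is installed and provides transforms.resize; `to_32x32` is a hypothetical helper name introduced only for this illustration.

# Hedged sketch: wrap MNIST in a TransformDataset that resizes each image.
import chainer
from chainer.datasets import TransformDataset
from chainercv.transforms import resize

def to_32x32(in_data):
    # in_data is an (image, label) pair; images are CHW float arrays
    img, label = in_data
    return resize(img, (32, 32)), label

mnist_train, _ = chainer.datasets.get_mnist(ndim=3, rgb_format=True)
mnist_train_32 = TransformDataset(mnist_train, to_32x32)
img, label = mnist_train_32[0]
print(img.shape)   # expected: (3, 32, 32)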