def __call__(self, trainer):
    """Translate ``self.test_data`` in mini-batches and report corpus BLEU.

    Each item of ``self.test_data`` is unpacked as a (source, target) pair.
    Every target is used as the single reference for its source. The score
    is multiplied by 100, printed, and reported under ``self.key``.

    Args:
        trainer: Not used by this method.
    """
    print('## Calculate BLEU')
    refs = []
    hyps = []
    # Translate with backprop mode disabled and the 'train' config off.
    with chainer.no_backprop_mode(), chainer.using_config('train', False):
        for start in range(0, len(self.test_data), self.batch):
            minibatch = self.test_data[start:start + self.batch]
            src_batch, tgt_batch = zip(*minibatch)
            # One-element reference list per sentence.
            refs += [[tgt.tolist()] for tgt in tgt_batch]
            on_device = [
                chainer.dataset.to_device(self.device, src)
                for src in src_batch
            ]
            translated = self.model.translate(on_device, self.max_length)
            hyps += [out.tolist() for out in translated]

    smoothing = bleu_score.SmoothingFunction().method1
    score = bleu_score.corpus_bleu(
        refs, hyps, smoothing_function=smoothing) * 100
    print('BLEU:', score)
    reporter.report({self.key: score})
# Example source code for the Python class dataset()
def is_dataset(obj):
    """Return True if ``obj`` is an instance of a known dataset class.

    The classes checked are DictDataset, ImageDataset, LabeledImageDataset,
    TupleDataset and DatasetMixin.
    """
    dataset_types = (
        DictDataset,
        ImageDataset,
        LabeledImageDataset,
        TupleDataset,
        DatasetMixin,
    )
    return isinstance(obj, dataset_types)
def _check_X_y(self, X, y=None):
"""Check type of X and y.
It updates the format of X and y (such as dtype, convert sparse matrix
to matrix format etc) if necessary.
`X` and `y` might be array (numpy.ndarray or sparse matrix) for sklearn
interface, but `X` might be chainer dataset.
:param X: chainer dataset type or array
:param y: None or array
:return:
"""
return X, y
def fit(self, X, y=None, **kwargs):
    """Train the model, delegating the actual work to ``fit_core``.

    The keyword arguments are first passed through ``filter_sk_params``
    together with ``fit_core``; only the filtered result is forwarded.

    When a hyper parameter is given as None, the instance's own variable is
    used instead, which is what makes grid search via ``set_params`` work.
    If the instance's variable is not set either, ``_default_hyperparam``
    is used.

    Usage: ``model.fit(train_dataset)`` or ``model.fit(X, y)``

    Args:
        X: training data; a chainer dataset, or an array when ``y`` is
            given.
        y: None, or the target array.
        **kwargs: optional training settings:
            test: test dataset for evaluation, assumes chainer's dataset
                class
            batchsize: batch size for both training and evaluation
            iterator_class: iterator class used for this training,
                currently assumes SerialIterator or MultiProcessIterator
            optimizer: optimizer instance used to update parameters
            epoch: number of training epochs
            out: directory path to save the result
            snapshot_frequency (int): snapshot frequency in epochs.
                A negative value means no snapshot is taken.
            dump_graph: save computational graph info or not, default
                is False
            log_report: enable LogReport or not
            plot_report: enable PlotReport or not
            print_report: enable PrintReport or not
            progress_report: enable ProgressReport or not
            resume: path of a saved trainer to resume training from

    Returns:
        Whatever ``fit_core`` returns.
    """
    accepted = self.filter_sk_params(self.fit_core, kwargs)
    return self.fit_core(X, y, **accepted)
def _check_X_y(self, X, y=None):
    """Validate ``X`` and ``y`` and return the (possibly converted) pair.

    ``X`` is left untouched when it is a dataset (see ``is_dataset``) or a
    list. Otherwise a numpy array is run through ``check_array`` with
    ``self._data_x_dtype``, and any other type only triggers a warning.
    ``y`` is run through ``check_array`` with ``self._data_y_dtype``
    whenever it is not None.

    :param X: dataset, list, numpy array or another container
    :param y: None or array
    :return: the tuple ``(X, y)``
    """
    # NOTE(review): the nesting of the warning branch was reconstructed from
    # unindented source; confirm it belongs to the ndarray check.
    needs_check = not (is_dataset(X) or isinstance(X, list))
    if needs_check and isinstance(X, numpy.ndarray):
        X = check_array(X, dtype=self._data_x_dtype)
    elif needs_check:
        print('[WARNING] skip check type for dataset X with type {}'
              .format(type(X)))

    if y is not None:
        # ensure_2d=False is passed for the target array.
        y = check_array(y, dtype=self._data_y_dtype, ensure_2d=False)
    return X, y
def main():
    """Command-line entry point of the MNIST example.

    Parses the options, loads MNIST, wraps an MLP in
    ``SklearnWrapperClassifier`` and trains it in one of three ways selected
    by ``--example``; finally the model is saved as ``<out>/mlp.model``.
    """
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    # (flags, add_argument keyword arguments), in registration order.
    option_table = (
        (('--batchsize', '-b'),
         dict(type=int, default=100,
              help='Number of images in each mini-batch')),
        (('--epoch', '-e'),
         dict(type=int, default=20,
              help='Number of sweeps over the dataset to train')),
        (('--frequency', '-f'),
         dict(type=int, default=-1,
              help='Frequency of taking a snapshot')),
        (('--gpu', '-g'),
         dict(type=int, default=-1,
              help='GPU ID (negative value indicates CPU)')),
        (('--out', '-o'),
         dict(default='result',
              help='Directory to output the result')),
        (('--resume', '-r'),
         dict(default='',
              help='Resume the training from snapshot')),
        (('--unit', '-u'),
         dict(type=int, default=50,
              help='Number of units')),
        (('--example', '-ex'),
         dict(type=int, default=3,
              help='Example mode')),
    )
    for flags, spec in option_table:
        parser.add_argument(*flags, **spec)
    args = parser.parse_args()

    banner = (
        ('GPU: {}', args.gpu),
        ('# unit: {}', args.unit),
        ('# Minibatch-size: {}', args.batchsize),
        ('# epoch: {}', args.epoch),
    )
    for template, value in banner:
        print(template.format(value))
    print('')

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    model = SklearnWrapperClassifier(MLP(args.unit, 10), device=args.gpu)

    mode = args.example
    if mode == 1:
        print("Example 1. fit with x, y numpy array (same with sklearn's fit)")
        x, y = concat_examples(train)
        model.fit(x, y)
    elif mode == 2:
        print("Example 2. Train with Chainer's dataset")
        # `train` is a TupleDataset here; passing it alone is enough to
        # train (but there is no validation).
        model.fit(train)
    else:
        print("Example 3. Train with configuration")
        # iterator_class, dump_graph, log_report and print_report are not
        # passed, so fit falls back to its own values for them.
        # NOTE(review): --frequency is parsed but not used; the snapshot
        # frequency is fixed at 1 here.
        fit_options = dict(
            test=test,
            batchsize=args.batchsize,
            optimizer=chainer.optimizers.Adam(),
            epoch=args.epoch,
            out=args.out,
            snapshot_frequency=1,
            plot_report=False,
            progress_report=False,
            resume=args.resume,
        )
        model.fit(train, **fit_options)

    # Save trained model
    model_path = '{}/mlp.model'.format(args.out)
    serializers.save_npz(model_path, model)