def train(self, X_train, X_val):
    """Run the training loop for ``P.N_EPOCHS`` epochs.

    Samples are split by the value at index 2 (``== 1`` vs ``== 0``;
    presumably a binary label -- TODO confirm against the data loader).
    Samples with any other value at index 2 are not used. The four splits
    are bound into ``make_epoch``, which ``ParallelBatchIterator`` calls once
    per epoch index to produce a ``(train_data, val_data)`` pair. Each epoch's
    data is chunked into batches and run through ``self.train_fn`` /
    ``self.val_fn``. Afterwards the learning rate is set to
    ``P.LEARNING_RATE * 0.985 ** self.epoch``.

    Args:
        X_train: iterable of training samples, each indexable at ``[2]``.
        X_val: iterable of validation samples, each indexable at ``[2]``.
    """
    # Materialise the splits as lists. On Python 3 ``filter`` returns a
    # one-shot iterator: ``len()`` on it raises TypeError, and because the
    # same objects are reused by make_epoch for every epoch, every epoch
    # after the first would see an exhausted (empty) split.
    train_true = [x for x in X_train if x[2] == 1]
    train_false = [x for x in X_train if x[2] == 0]
    val_true = [x for x in X_val if x[2] == 1]
    val_false = [x for x in X_val if x[2] == 0]

    make_epoch_helper = functools.partial(
        make_epoch,
        train_true=train_true,
        train_false=train_false,
        val_true=val_true,
        val_false=val_false,
    )

    logging.info("Starting training...")
    epoch_iterator = ParallelBatchIterator(
        make_epoch_helper,
        range(P.N_EPOCHS),
        ordered=False,
        batch_size=1,
        multiprocess=False,
        n_producers=1,
    )

    for epoch_values in epoch_iterator:
        self.pre_epoch()
        train_epoch_data, val_epoch_data = epoch_values

        train_batches = util.chunks(train_epoch_data, P.BATCH_SIZE_TRAIN)
        val_batches = util.chunks(val_epoch_data, P.BATCH_SIZE_VALIDATION)

        self.do_batches(self.train_fn, train_batches, self.train_metrics)
        self.do_batches(self.val_fn, val_batches, self.val_metrics)

        self.post_epoch()

        # Exponential decay by 0.985 per epoch. NOTE(review): self.epoch is
        # presumably advanced in pre_epoch/post_epoch -- confirm.
        learning_rate = P.LEARNING_RATE * (0.985 ** self.epoch)
        logging.info("Setting learning rate to {}".format(learning_rate))
        self.l_r.set_value(learning_rate)
评论列表
文章目录