def execute(self,
            dataset: Dataset,
            execution_scripts,
            train: bool = False,
            compute_losses: bool = True,
            summaries: bool = True,
            batch_size=None,
            log_progress: int = 0) -> List[ExecutionResult]:
    """Run the given execution scripts over a dataset, batch by batch.

    Args:
        dataset: The data to run on; it is split with
            ``dataset.batch_dataset(batch_size)``.
        execution_scripts: Objects providing ``get_executable``; one result
            is produced for each of them.
        train: Forwarded to ``self._run_executables`` (presumably enables
            the training update -- confirm against that method).
        compute_losses: Forwarded to each script's ``get_executable``.
        summaries: Forwarded to each script's ``get_executable``.
        batch_size: Number of examples per batch. ``None`` means the whole
            dataset is processed as a single batch.
        log_progress: Minimum number of wall-clock seconds between
            "Processed N examples." log messages; ``0`` (or a negative
            value) disables progress logging.

    Returns:
        One ``ExecutionResult`` per script, in the order of
        ``execution_scripts``, obtained by merging that script's per-batch
        results with ``reduce_execution_results``.
    """
    if batch_size is None:
        # NOTE(review): an empty dataset yields batch_size == 0 here; how
        # ``batch_dataset(0)`` behaves is not visible from this method.
        batch_size = len(dataset)
    batched_dataset = dataset.batch_dataset(batch_size)

    # Throttle progress messages on wall-clock (monotonic) time, not CPU
    # time: ``time.process_time`` sums CPU time of all threads of the
    # process, so it runs ahead of real time under multi-threaded work and
    # stands still while the process waits, making the interval arbitrary.
    last_log_time = time.monotonic()

    # batch_results[i] collects the per-batch results of execution_scripts[i].
    batch_results = [
        [] for _ in execution_scripts]  # type: List[List[ExecutionResult]]

    for batch_id, batch in enumerate(batched_dataset):
        if (log_progress > 0
                and time.monotonic() - last_log_time > log_progress):
            # Examples in all previous batches (assumes those were full-sized).
            log("Processed {} examples.".format(batch_id * batch_size))
            last_log_time = time.monotonic()

        executables = [s.get_executable(compute_losses=compute_losses,
                                        summaries=summaries,
                                        num_sessions=len(self.sessions))
                       for s in execution_scripts]

        # An executable may need more than one ``_run_executables`` call
        # before its ``result`` is set, so keep running until all are done.
        while any(ex.result is None for ex in executables):
            self._run_executables(batch, executables, train)

        for script_list, executable in zip(batch_results, executables):
            script_list.append(executable.result)

    # Merge each script's per-batch results into a single result.
    return [reduce_execution_results(result_list)
            for result_list in batch_results]
评论列表
文章目录