def batch_train(opt, round_index, round_train_data, round_valid_data,
                round_valid_weights=None, save_all=True, file_indices=None,
                return_acc_len=False, seq2seq=False):
    """Train one model per data split in parallel batches, retrying failures once.

    Jobs are run through joblib ``Parallel`` with the threading backend, at
    most ``opt['num_machines']`` at a time.  A job counts as failed when
    ``train`` returns ``0.0`` or a tuple whose first element is ``0.0``.
    After the first pass, every failed job is re-run once, spread over the
    machine slots ``< opt['num_machines']`` whose first-pass job succeeded.

    Args:
        opt: Options dict forwarded to ``train``; ``opt['num_machines']``
            (read only when there is at least one job) caps concurrency.
        round_index: Forwarded to ``train``.
        round_train_data: Indexable sequence of training splits, one per job.
        round_valid_data: Validation splits aligned with ``round_train_data``.
        round_valid_weights: Optional per-job validation weights, passed to
            ``train`` as ``valid_weights`` (``None`` when omitted).
        save_all: Forwarded to ``train``.
        file_indices: Optional per-job file index passed to ``train``; when
            ``None`` the job index itself is used.
        return_acc_len: Forwarded to ``train``.
        seq2seq: Forwarded to ``train``.

    Returns:
        List with one ``train`` result per job, in job order.  A job that
        failed the first pass holds its retry result (which may itself be a
        failure); if no machine slot succeeded, the first-pass failure is kept.

    Raises:
        ValueError: If there are jobs but ``opt['num_machines'] < 1``.
    """
    num_jobs = len(round_train_data)
    if num_jobs == 0:
        return []
    num_machines = opt['num_machines']
    if num_machines < 1:
        raise ValueError(
            "opt['num_machines'] must be >= 1, got %r" % (num_machines,))

    def _is_failure(perf):
        # ``train`` reports failure as a 0.0 score, either bare or as the
        # first element of a tuple.
        return perf == 0.0 or (isinstance(perf, tuple) and perf[0] == 0.0)

    def _train_job(machine_index, job_index):
        # Run ``train`` for one data split on one machine slot.  Use
        # ``is not None`` rather than truthiness so array-like inputs work.
        file_index = (file_indices[job_index]
                      if file_indices is not None else job_index)
        valid_weights = (round_valid_weights[job_index]
                         if round_valid_weights is not None else None)
        return train(opt, round_index, machine_index, file_index,
                     round_train_data[job_index], round_valid_data[job_index],
                     valid_weights=valid_weights, save_all=save_all,
                     return_acc_len=return_acc_len, seq2seq=seq2seq)

    def _run_batches(assignments, width):
        # Run (machine_index, job_index) pairs, at most ``width`` concurrently
        # per batch; results come back in assignment order.
        results = []
        for start in range(0, len(assignments), width):
            batch = assignments[start:start + width]
            results.extend(Parallel(n_jobs=len(batch), backend='threading')(
                delayed(_train_job)(machine, job) for machine, job in batch))
        return results

    # First pass: the job index itself is passed as train's machine argument.
    # NOTE(review): that value exceeds num_machines when num_jobs > num_machines;
    # presumably train maps it onto a device internally — confirm.
    perfs = _run_batches([(job, job) for job in range(num_jobs)], num_machines)

    failed_jobs = [job for job, perf in enumerate(perfs) if _is_failure(perf)]
    # Only the first batch's indices identify machine slots; keep those whose
    # job succeeded as candidates for the retry.
    healthy_machines = [machine
                        for machine, perf in enumerate(perfs[:num_machines])
                        if not _is_failure(perf)]
    if failed_jobs and healthy_machines:
        # Assign failed jobs round-robin over healthy slots.  Batches have
        # width len(healthy_machines) and start at multiples of it, so every
        # slot runs at most one job per batch and every lookup stays in range
        # even when failures outnumber healthy machines.
        width = len(healthy_machines)
        assignments = [(healthy_machines[pos % width], job)
                       for pos, job in enumerate(failed_jobs)]
        for job, perf in zip(failed_jobs, _run_batches(assignments, width)):
            perfs[job] = perf
    return perfs
# (Scraped web-page residue, not code: "Comment list" / "Table of contents".)