run.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:ParlAI 作者: facebookresearch 项目源码 文件源码
def batch_train(opt, round_index, round_train_data, round_valid_data, round_valid_weights=None, save_all=True, file_indices=None, return_acc_len=False, seq2seq=False):
    """Run ``train`` once per entry of ``round_train_data``, in parallel batches.

    Jobs are dispatched through ``Parallel`` with the threading backend, at
    most ``opt['num_machines']`` at a time. A result equal to ``0.0`` (or a
    tuple whose first element is ``0.0``) is treated as a failure and the
    job is retried once, reusing only the slots of the first
    ``opt['num_machines']`` jobs that did not fail.

    Args:
        opt: Options mapping; ``opt['num_machines']`` is read here and the
            whole mapping is forwarded to ``train``.
        round_index: Forwarded unchanged to ``train``.
        round_train_data: Sequence with one training set per job.
        round_valid_data: Sequence with one validation set per job, indexed
            in parallel with ``round_train_data``.
        round_valid_weights: Optional sequence of per-job validation
            weights; when falsy, ``valid_weights=None`` is passed.
        save_all: Forwarded unchanged to ``train``.
        file_indices: Optional sequence mapping a job position to the file
            index handed to ``train``; when falsy, the job position itself
            is used.
        return_acc_len: Forwarded unchanged to ``train``.
        seq2seq: Forwarded unchanged to ``train``.

    Returns:
        A list with one ``train`` result per entry of ``round_train_data``,
        in the same order. Results of failed jobs are replaced by the
        result of their retry when at least one slot was usable.
    """

    def make_job(slot, data_index):
        # ``slot`` is the third positional argument of ``train``
        # (presumably a machine/worker index -- confirm against ``train``).
        return delayed(train)(
            opt,
            round_index,
            slot,
            file_indices[data_index] if file_indices else data_index,
            round_train_data[data_index],
            round_valid_data[data_index],
            valid_weights=round_valid_weights[data_index] if round_valid_weights else None,
            save_all=save_all,
            return_acc_len=return_acc_len,
            seq2seq=seq2seq,
        )

    # First pass: every job, in consecutive batches of ``num_machines``.
    total = len(round_train_data)
    perfs = []
    start = 0
    while start < total:
        stop = min(start + opt['num_machines'], total)
        batch = Parallel(n_jobs=stop - start, backend='threading')(
            make_job(position, position) for position in range(start, stop)
        )
        perfs.extend(batch)
        start = stop

    # Split results into failed jobs and slots that are known to work.
    error_indices, valid_indices = [], []
    for position, perf in enumerate(perfs):
        if perf == 0.0 or isinstance(perf, tuple) and perf[0] == 0.0:
            error_indices.append(position)
        elif position < opt['num_machines']:
            valid_indices.append(position)

    num_errors = len(error_indices)
    num_slots = len(valid_indices)
    if num_errors > 0 and num_slots > 0:
        # Second pass: retry failures, at most ``num_slots`` at a time.
        error_perfs = []
        start = 0
        while start < num_errors:
            stop = min(start + num_slots, num_errors)
            # Index the slot by the offset inside the batch: a batch never
            # holds more than ``num_slots`` jobs, whereas the absolute
            # position would run past ``valid_indices`` as soon as there
            # are more failures than usable slots.
            batch = Parallel(n_jobs=stop - start, backend='threading')(
                make_job(valid_indices[position - start], error_indices[position])
                for position in range(start, stop)
            )
            error_perfs.extend(batch)
            start = stop
        for data_index, perf in zip(error_indices, error_perfs):
            perfs[data_index] = perf

    return perfs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号