test.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:DRPN 作者: w7829 项目源码 文件源码
def Data_iterate_minibatches(inputs, targets, batchsize, arg=False, genSetting=None, shuffle=False, warpMode=None):
    """Yield minibatches while a background thread prefetches the next one.

    Loading of each batch is delegated to ``Data_readPics_thread``, launched
    with ``thread.start_new_thread``. Results are handed back through the
    module globals ``input_tmp``, ``target_tmp`` and ``isOK``: this generator
    polls ``isOK`` every 10 ms, then reads ``input_tmp``/``target_tmp``.
    NOTE(review): assumes ``Data_readPics_thread`` writes those two globals and
    then sets ``isOK = True`` -- confirm against its definition.

    Because all state lives in module globals, only one instance of this
    generator should be active at a time. If a generator is abandoned early,
    its already-launched prefetch thread keeps running and may still set
    ``isOK`` later.

    Args:
        inputs: Sequence of parallel per-sample sequences; ``inputs[k][i]`` is
            field ``k`` of sample ``i``. The sample count is taken from
            ``inputs[0]``; the other fields are assumed to be at least as long.
        targets: Same layout as ``inputs``.
        batchsize: Number of samples per batch; must be >= 1.
        arg, genSetting, warpMode: Forwarded unchanged to
            ``Data_readPics_thread``, whose positional order is
            ``(inputs, targets, batchsize, genSetting, arg, warpMode)``.
        shuffle: If true, samples are reordered by one random permutation that
            is applied to every field of ``inputs`` and ``targets``. The
            permutation is done on deep copies, so the caller's containers are
            not modified.

    Yields:
        ``inputsbatch + targetsbatch`` for each of the
        ``len(inputs[0]) // batchsize`` full batches, in order. A trailing
        partial batch is dropped, and nothing is yielded if there are fewer
        than ``batchsize`` samples.

    Raises:
        ValueError: If ``batchsize`` is less than 1.
    """
    global input_tmp
    global target_tmp
    global isOK

    if batchsize < 1:
        raise ValueError("batchsize must be >= 1, got %r" % (batchsize,))

    if shuffle:
        # Permute on deep copies so the caller's data is left untouched. The
        # same permutation is used for every field to keep samples aligned.
        rinputs = copy.deepcopy(inputs)
        rtargets = copy.deepcopy(targets)
        indices = np.random.permutation(len(inputs[0]))
        for dst, src in enumerate(indices):
            for idx in range(len(inputs)):
                rinputs[idx][dst] = inputs[idx][src]
            for idx in range(len(targets)):
                rtargets[idx][dst] = targets[idx][src]
        inputs = rinputs
        targets = rtargets

    num_batches = len(inputs[0]) // batchsize
    if num_batches == 0:
        # Not even one full batch. No prefetch thread would be started, so
        # waiting on isOK would spin forever (the original code hung here).
        return

    def _start_prefetch(start_idx):
        # Launch Data_readPics_thread on samples [start_idx, start_idx + batchsize).
        # NOTE(review): with batchsize == 1, itemgetter(k) returns the bare item
        # rather than a 1-tuple. This is kept as in the original -- confirm the
        # reader handles it.
        sl = range(start_idx, start_idx + batchsize)
        getter = itemgetter(*sl)
        thread.start_new_thread(
            Data_readPics_thread,
            ([getter(field) for field in inputs],
             [getter(field) for field in targets],
             batchsize, genSetting, arg, warpMode))

    # Clear any "ready" flag left over from a previous run before the first
    # prefetch; otherwise stale input_tmp/target_tmp could be yielded.
    isOK = False
    _start_prefetch(0)
    for batch_no in range(num_batches):
        while not isOK:
            time.sleep(0.01)
        inputsbatch, targetsbatch = input_tmp, target_tmp
        isOK = False
        # Start loading the next batch before handing this one to the consumer,
        # so loading overlaps with the consumer's work.
        if batch_no + 1 < num_batches:
            _start_prefetch((batch_no + 1) * batchsize)
        yield inputsbatch + targetsbatch
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号