train_utils.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:TensorBox 作者: Russell91 项目源码 文件源码
def load_idl_tf(idlfile, H, jitter):
    """Create an endless generator of training samples from an annotation file.

    The annotation list is parsed once with ``al.parse`` and every image name
    is rewritten to live under the directory that contains ``idlfile``. The
    generator then iterates over epochs forever, reshuffling the annotations
    at the start of each epoch.

    Args:
        idlfile: Path to the annotation file handed to ``al.parse``.
        H: Configuration dict. Reads ``H['data']['truncate_data']``,
            ``H['image_height']``, ``H['image_width']``, ``H['grid_width']``,
            ``H['grid_height']`` and ``H['rnn_len']``. ``H['grayscale_prob']``
            is read only when both the ``'grayscale'`` and
            ``'grayscale_prob'`` keys are present.
        jitter: When truthy, each sample is passed through
            ``annotation_jitter`` before being encoded.

    Yields:
        dict: ``{"image": I, "boxes": boxes, "flags": flags}`` where ``I`` is
        the loaded (and possibly resized / jittered) image and ``boxes`` /
        ``flags`` are the two values returned by ``annotation_to_h5``.

    Raises:
        ValueError: On first iteration, if the annotation list is empty
            (otherwise the epoch loop would spin forever without yielding).

    Note:
        Calls ``random.seed(0)``, which reseeds the module-level RNG.
        A sample whose loading or preprocessing raises is reported on stdout
        and skipped rather than aborting the generator.
    """
    # Resolve the annotation directory once instead of once per annotation.
    base_dir = os.path.dirname(os.path.realpath(idlfile))
    annos = []
    for anno in al.parse(idlfile):
        anno.imageName = os.path.join(base_dir, anno.imageName)
        annos.append(anno)

    random.seed(0)
    if H['data']['truncate_data']:
        annos = annos[:10]
    if not annos:
        raise ValueError('No annotations found in %s' % idlfile)

    # Fixed augmentation settings forwarded to annotation_jitter.
    jitter_scale_min = 0.9
    jitter_scale_max = 1.1
    jitter_offset = 16

    for epoch in itertools.count():
        random.shuffle(annos)
        for anno in annos:
            # Captured up front: `anno` is rebound below, and the path is
            # needed for the error message if preprocessing fails.
            image_path = anno.imageName
            try:
                # NOTE(review): only the presence of the 'grayscale' key is
                # tested; its value is never read - confirm this is intended.
                if 'grayscale' in H and 'grayscale_prob' in H:
                    # NOTE(review): as written, the image is loaded as 'RGB'
                    # with probability grayscale_prob and as 'L' otherwise,
                    # which looks inverted relative to the option's name.
                    # Left unchanged - confirm the intended semantics.
                    mode = 'RGB' if random.random() < H['grayscale_prob'] else 'L'
                    I = imread(image_path, mode=mode)
                    if len(I.shape) < 3:
                        # Expand a 2-D image back to three channels.
                        I = cv2.cvtColor(I, cv2.COLOR_GRAY2RGB)
                else:
                    # The image must be read BEFORE its shape is inspected;
                    # the previous order raised UnboundLocalError for every
                    # sample, so nothing was ever yielded on this path.
                    I = imread(image_path, mode='RGB')
                    if len(I.shape) < 3:
                        continue

                if I.shape[0] != H["image_height"] or I.shape[1] != H["image_width"]:
                    # Boxes are rescaled during the first epoch only.
                    # NOTE(review): presumably rescale_boxes updates the
                    # annotation in place so later epochs see the rescaled
                    # boxes - verify against its implementation.
                    if epoch == 0:
                        anno = rescale_boxes(I.shape, anno, H["image_height"], H["image_width"])
                    I = imresize(I, (H["image_height"], H["image_width"]), interp='cubic')

                if jitter:
                    I, anno = annotation_jitter(I,
                                                anno, target_width=H["image_width"],
                                                target_height=H["image_height"],
                                                jitter_scale_min=jitter_scale_min,
                                                jitter_scale_max=jitter_scale_max,
                                                jitter_offset=jitter_offset)

                boxes, flags = annotation_to_h5(H,
                                                anno,
                                                H["grid_width"],
                                                H["grid_height"],
                                                H["rnn_len"])
            except Exception as exc:
                # Best effort: report the bad sample and move on to the next.
                print('Skipping %s: %s' % (image_path, exc))
                continue

            # Yield outside the try block so that exceptions thrown into the
            # generator by the consumer are not swallowed by the handler.
            yield {"image": I, "boxes": boxes, "flags": flags}
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号