test_image_iterators.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目：ml-tools　作者：triagemd　项目源码　文件源码
def test_multi_directory_iterator_race_condition(sample_dataset_dir):
    """Smoke-test MultiDirectoryIterator while many fit_generator workers read from it.

    For both the ``Training`` and ``Validation`` sub-directories of
    ``sample_dataset_dir``, one directory iterator is created per sub-model
    and the group is wrapped in a ``MultiDirectoryIterator``. ``n_models``
    MobileNets (built with ``weights=None``) are then joined into a single
    multi-input softmax classifier, which is trained for a few epochs with
    ``workers=16`` pulling batches from those shared iterators.

    There is intentionally no assertion: the test passes as long as nothing
    raises during model construction or training.
    """
    n_models = 2
    batch_size = 4

    def _multi_iterator(split_dir):
        # One directory iterator per sub-model, all reading the same split.
        # NOTE(review): presumably MultiDirectoryIterator yields one input
        # array per wrapped iterator (one per model input) — confirm.
        return MultiDirectoryIterator(
            [make_dir_iterator(split_dir, batch_size) for _ in range(n_models)]
        )

    def _suffixed_mobilenet(index):
        # Append the model index to every layer name, presumably so names stay
        # unique once the sub-models are joined into one Model.
        # NOTE(review): newer tf.keras exposes Layer.name as read-only, in
        # which case this assignment would fail — confirm the Keras version.
        net = MobileNet(weights=None)
        for layer in net.layers:
            layer.name += str(index)
        return net

    def _steps(generator):
        # Number of batches needed to cover every sample once (last batch may be partial).
        return int(np.ceil(generator.samples / batch_size))

    train_dir = os.path.join(sample_dataset_dir, 'Training')
    val_dir = os.path.join(sample_dataset_dir, 'Validation')

    # Training and validation generators, each feeding all sub-models.
    train_gen = _multi_iterator(train_dir)
    val_gen = _multi_iterator(val_dir)

    # Build the sub-models one at a time (build + rename before the next one).
    base_models = [_suffixed_mobilenet(idx) for idx in range(n_models)]

    # Join the sub-model outputs into a single classification head whose width
    # equals the number of entries in the training class histogram.
    merged = concatenate([net.output for net in base_models])
    n_classes = create_class_histogram(train_dir).shape[0]
    logits = Dense(n_classes, name='dense')(merged)
    probabilities = Activation('softmax', name='act_softmax')(logits)

    joined_model = Model([net.input for net in base_models], probabilities)

    # Train for a few epochs with many workers sharing the generators.
    # NOTE(review): fit_generator is deprecated in TF2 Keras in favor of fit();
    # kept as-is since the targeted Keras version is not visible here.
    joined_model.compile(optimizer=optimizers.SGD(), loss='categorical_crossentropy')
    joined_model.fit_generator(
        train_gen,
        validation_data=val_gen,
        epochs=4,
        workers=16,
        steps_per_epoch=_steps(train_gen),
        validation_steps=_steps(val_gen),
    )

    # intentionally no assert, test passes if nothing throws
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号