test_parallel.py source code

python

Project: ml-tools  Author: triagemd
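```python
from keras.layers import Input, Dense, Dropout
from keras.models import Model
from keras.callbacks import LambdaCallback

# make_parallel is a helper from the ml-tools project itself;
# its import path is not shown in this snippet.


def run_parallel_test(data_generator):
    # Two independent inputs, each mapped to its own output.
    a = Input(shape=(3,), name='input_a')
    b = Input(shape=(3,), name='input_b')
    a_2 = Dense(4, name='dense_1')(a)
    dp = Dropout(0.5, name='dropout')
    b_2 = dp(b)

    optimizer = 'rmsprop'
    loss = 'mse'
    loss_weights = [1., 0.5]  # weight for output a_2, then for b_2

    model = Model([a, b], [a_2, b_2])
    # Wrap the model with the ml-tools helper; the second argument is
    # presumably the number of devices (GPUs) to parallelize over.
    model = make_parallel(model, 2)
    model.compile(optimizer, loss,
                  metrics=[],
                  loss_weights=loss_weights,
                  sample_weight_mode=None)

    # Record the index of every epoch that actually starts.
    trained_epochs = []
    tracker_cb = LambdaCallback(
        on_epoch_begin=lambda epoch, logs: trained_epochs.append(epoch))

    # data_generator(4): the argument is presumably the batch size.
    model.fit_generator(data_generator(4),
                        steps_per_epoch=3,
                        epochs=5,
                        initial_epoch=2,
                        callbacks=[tracker_cb])

    # Training resumes at epoch 2 and stops before epoch 5.
    assert trained_epochs == [2, 3, 4]
```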
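The assertion expects `[2, 3, 4]` because Keras runs epochs from `initial_epoch` up to, but not including, `epochs`. The pure-Python sketch below mimics only that loop shape so the expected list is easy to verify; `fake_fit` is a hypothetical stand-in, not Keras code, and it performs no training.

```python
def fake_fit(epochs, initial_epoch, steps_per_epoch, on_epoch_begin):
    """Hypothetical stand-in for the epoch loop of a Keras-style fit call.

    Only models which epoch indices are visited and how many batches
    are drawn; no training happens.
    """
    steps_run = 0
    for epoch in range(initial_epoch, epochs):  # 2, 3, 4 for (5, 2)
        on_epoch_begin(epoch, {})
        for _ in range(steps_per_epoch):
            steps_run += 1  # one generator batch per step
    return steps_run


trained_epochs = []
total_steps = fake_fit(
    epochs=5, initial_epoch=2, steps_per_epoch=3,
    on_epoch_begin=lambda epoch, logs: trained_epochs.append(epoch))
print(trained_epochs)  # [2, 3, 4]
print(total_steps)     # 9
```

With `steps_per_epoch=3` and three epochs run, the generator is asked for 9 batches in total.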
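The snippet does not define the `data_generator` it receives. A minimal sketch, assuming the argument is the batch size and that random data is acceptable for a smoke test, only needs to match the model's shapes: two inputs of width 3, a target of width 4 for `dense_1`, and a target of width 3 for `dropout`. The function below is hypothetical, not the ml-tools implementation.

```python
import numpy as np


def data_generator(batch_size):
    """Hypothetical generator matching the model in run_parallel_test.

    Yields (inputs, targets) forever: two input arrays of shape
    (batch_size, 3), plus targets for dense_1 (batch_size, 4)
    and dropout (batch_size, 3).
    """
    rng = np.random.default_rng(0)  # fixed seed for reproducible batches
    while True:
        inputs = [rng.random((batch_size, 3)), rng.random((batch_size, 3))]
        targets = [rng.random((batch_size, 4)), rng.random((batch_size, 3))]
        yield inputs, targets
```

The generator never terminates, which is what `fit_generator` expects: it pulls exactly `steps_per_epoch` batches per epoch.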