def test_experiment_fit_gen_async(self, get_model, get_loss_metric,
                                  get_custom_l):
    """Exercise ``Experiment.fit_gen_async`` with three validation inputs.

    The asynchronous generator fit is launched once per validation variant:

    * ``None``;
    * a second generator built with ``make_gen`` (the literal ``1`` in the
      loop list is only a placeholder that requests this variant);
    * the validation data returned by ``make_data``.

    After each run the returned thread is joined and every entry of
    ``expe.full_res['metrics']`` whose key does not contain ``'iter'`` must
    hold exactly ``expected_value`` items (2, the same number as
    ``nb_epoch`` -- presumably one value per epoch; confirm against
    ``Experiment``).

    Args:
        get_model: callable invoked with ``get_custom_l`` to build the model
            (presumably a pytest fixture -- confirm in the test module).
        get_loss_metric: forwarded to ``prepare_model``.
        get_custom_l: forwarded to ``get_model`` and ``prepare_model``.
    """
    new_session()
    model, metrics, cust_objects = prepare_model(get_model(get_custom_l),
                                                 get_loss_metric,
                                                 get_custom_l)
    _, data_val_use = make_data(train_samples, test_samples)
    expe = Experiment(model)
    expected_value = 2
    for val in [None, 1, data_val_use]:
        gen, data, data_stream = make_gen(batch_size)
        # Record the decision *before* ``val`` is rebound to the generator:
        # re-testing ``val == 1`` afterwards is always false, which used to
        # skip the cleanup of the validation generator.
        val_is_gen = val == 1
        if val_is_gen:
            val, data_2, data_stream_2 = make_gen(batch_size)
        try:
            _, thread = expe.fit_gen_async([gen], [val], nb_epoch=2,
                                           model=model,
                                           metrics=metrics,
                                           custom_objects=cust_objects,
                                           samples_per_epoch=64,
                                           nb_val_samples=128,
                                           verbose=2, overwrite=True)
            # Wait for the background fit to finish before reading results.
            thread.join()
            for k in expe.full_res['metrics']:
                if 'iter' not in k:
                    assert len(
                        expe.full_res['metrics'][k]) == expected_value
        finally:
            # Release the generators even when the fit or an assertion fails.
            close_gens(gen, data, data_stream)
            if val_is_gen:
                close_gens(val, data_2, data_stream_2)
    if K.backend() == 'tensorflow':
        K.clear_session()
    print(self)
评论列表
文章目录