def test_fit(self, get_model):
    """Test training and fitting of a serialized model.

    Builds a model via ``get_model``, compiles it, serializes it with
    ``to_dict_w_opt`` and then runs both ``KTB.train`` and ``KTB.fit`` on the
    same generated data. Only the result of ``KTB.fit`` is checked: it must
    contain exactly 4 elements.

    Args:
        get_model: callable (presumably a pytest fixture) returning a model
            object that is compiled inside this test.
    """
    new_session()
    try:
        data, data_val = make_data(train_samples, test_samples)
        model = get_model()
        model.compile(loss='categorical_crossentropy',
                      optimizer='rmsprop',
                      metrics=['accuracy'])
        model_dict = {'model_arch': to_dict_w_opt(model)}
        # Give train a deep copy so that whatever it does with the
        # architecture dict cannot affect the fit call below
        # (presumably train mutates its input -- TODO confirm).
        # Its return value is not inspected; this call only checks that
        # training runs without raising.
        KTB.train(copy.deepcopy(model_dict['model_arch']), [data],
                  [data_val], [])
        res = KTB.fit(NAME, VERSION, model_dict, [data], 'test', [data_val],
                      [])
        assert len(res) == 4
    finally:
        # Clear backend state even when the test fails, so later tests do
        # not inherit it. Only done for the TensorFlow backend.
        if K.backend() == 'tensorflow':
            K.clear_session()
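# NOTE: stray page-navigation text ("Comment list", "Table of contents")
# copied from the web page this code came from; not part of the test.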