def test_EarlyStopping_reuse():
    """Check that one EarlyStopping instance can be reused across fits.

    The same ``stopper`` callback is passed to two consecutive
    ``model.fit`` calls, with the model weights reset to their initial
    values in between. Each run must last at least ``patience`` epochs.
    """
    patience = 3

    # Tiny binary-classification problem: label is 1 where the single
    # random feature exceeds 0.5, otherwise 0.
    data = np.random.random((100, 1))
    labels = np.where(data > 0.5, 1, 0)

    model = Sequential((
        Dense(1, input_dim=1, activation='relu'),
        Dense(1, activation='sigmoid'),
    ))
    model.compile(
        optimizer='sgd',
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )

    # NOTE(review): 'acc' is presumably the history key produced by
    # metrics=['accuracy'] in the Keras version in use -- confirm.
    stopper = callbacks.EarlyStopping(monitor='acc', patience=patience)

    # Snapshot the initial weights so the second run starts from the
    # same point as the first.
    weights = model.get_weights()

    # NOTE(review): no epoch count is passed, so both assertions rely on
    # fit's default number of epochs being at least `patience` -- confirm.
    hist = model.fit(data, labels, callbacks=[stopper])
    assert len(hist.epoch) >= patience

    # This should allow training to go for at least `patience` epochs
    model.set_weights(weights)
    hist = model.fit(data, labels, callbacks=[stopper])
    assert len(hist.epoch) >= patience
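# Web-page navigation residue from the source article (not part of the test):
# "Comment list" / "Article table of contents"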