def train_classifier(X_train, y_train,
                     X_test, y_test, model_path, output_path, epochs,
                     validation_split=0.5):
    """Fine-tune a pretrained base network as a two-class classifier.

    The base network is rebuilt from the input shape of ``X_train``, its
    weights are loaded from ``model_path``, and a 2-way head is attached via
    ``attach_classifier``. The model is compiled with RMSprop and
    categorical cross-entropy, then trained with ``model.fit``.

    Args:
        X_train: Training inputs. ``X_train.shape[1:]`` is used as the
            network input shape (assumes an array-like with ``.shape``).
        y_train: Class labels for ``X_train``. They are one-hot encoded to
            2 columns with ``to_categorical`` before training.
        X_test: Optional validation inputs.
        y_test: Optional validation labels. When both ``X_test`` and
            ``y_test`` are given they are used as ``validation_data``.
            Otherwise a ``validation_split`` fraction of the training data
            is held out instead.
        model_path: Path of the weights loaded into the base network.
        output_path: Directory for checkpoints written by
            ``SeparateSaveCallback``. If it already exists it is deleted
            and recreated. ``None`` disables checkpointing.
        epochs: Number of training epochs.
        validation_split: Fraction of the training data held out for
            validation when no complete test set is given. Defaults to 0.5,
            the value previously hard-coded.

    Returns:
        Whatever ``model.fit`` returns (the training history object).
    """
    model = create_base_network(X_train.shape[1:])
    model.load_weights(model_path)
    model = attach_classifier(model, 2)

    # NOTE(review): `decay` is a legacy Keras optimizer argument; newer Keras
    # releases removed it. Confirm against the installed version.
    opt = opts.RMSprop(epsilon=1e-4, decay=1e-6)
    model.compile(
        loss='categorical_crossentropy',
        metrics=[acc],
        optimizer=opt,
    )

    callbacks_list = []
    if output_path is not None:
        # Start from an empty directory so files from an earlier run are
        # not mixed with this run's checkpoints.
        if os.path.exists(output_path):
            shutil.rmtree(output_path)
        os.makedirs(output_path)
        # The filename references val_loss. That metric exists because the
        # fit call below always receives some form of validation data.
        file_fmt = '{epoch:02d}-{loss:.4f}-{val_loss:.4f}.hdf5'
        checkpoint = SeparateSaveCallback(
            output_path, file_fmt, siamese=False)
        callbacks_list = [checkpoint]

    y_train = to_categorical(y_train, 2)

    # Choose exactly one validation source; the rest of the fit call is shared.
    if X_test is not None and y_test is not None:
        validation_kwargs = {
            'validation_data': (X_test, to_categorical(y_test, 2)),
        }
    else:
        validation_kwargs = {'validation_split': validation_split}

    # `batch_size` is not a parameter here; it is resolved from an enclosing
    # (module) scope that is outside this view.
    history = model.fit(
        x=X_train,
        y=y_train,
        callbacks=callbacks_list,
        batch_size=batch_size,
        epochs=epochs,
        **validation_kwargs,
    )
    return history
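# 评论列表 ("Comment list") - web-page navigation residue, not code
# 文章目录 ("Table of contents") - web-page navigation residue, not code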