classifier.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:siam 作者: btlk 项目源码 文件源码
def train_classifier(X_train, y_train,
  X_test, y_test, model_path, output_path, epochs):
  """Fine-tune a 2-class classifier on top of a pre-trained base network.

  The base network is built from ``X_train.shape[1:]``, its weights are
  loaded from ``model_path``, and ``attach_classifier(model, 2)`` is applied
  before compiling with RMSprop and categorical cross-entropy.

  Args:
    X_train: Training inputs; must expose ``.shape`` (sample axis first).
    y_train: Training labels, passed through ``to_categorical(..., 2)``.
    X_test: Validation inputs, or ``None``.
    y_test: Validation labels, or ``None``.
    model_path: Path handed to ``model.load_weights``.
    output_path: Directory for checkpoints, or ``None`` to disable them.
      WARNING: if the directory already exists it is deleted and recreated.
    epochs: Number of epochs passed to ``model.fit``.

  Returns:
    Whatever ``model.fit`` returns (the training history).

  Note:
    ``batch_size`` is not a parameter; it is read from the enclosing
    (module) scope and must be defined there.
  """
  model = create_base_network(X_train.shape[1:])
  model.load_weights(model_path)
  model = attach_classifier(model, 2)

  model.compile(
    loss = 'categorical_crossentropy',
    metrics = [acc],
    optimizer = opts.RMSprop(epsilon = 1e-4, decay = 1e-6)
  )

  callbacks_list = []
  if output_path is not None:
    # Start from an empty directory so old checkpoints never mix with new ones.
    if os.path.exists(output_path):
      shutil.rmtree(output_path)
    os.makedirs(output_path)
    file_fmt = '{epoch:02d}-{loss:.4f}-{val_loss:.4f}.hdf5'
    callbacks_list.append(
      SeparateSaveCallback(output_path, file_fmt, siamese = False))

  # Arguments shared by both validation modes.
  fit_kwargs = dict(
    x = X_train,
    y = to_categorical(y_train, 2),
    callbacks = callbacks_list,
    batch_size = batch_size,
    epochs = epochs
  )

  if X_test is not None and y_test is not None:
    # An explicit validation set was supplied: use it as-is.
    fit_kwargs['validation_data'] = (X_test, to_categorical(y_test, 2))
  else:
    # No complete validation set: hold out half of the training data.
    fit_kwargs['validation_split'] = 0.5

  return model.fit(**fit_kwargs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号