def build_model(model, wrapper, dataset, hyperparams, reweighting):
    """Instantiate, compile and wrap a model selected by name.

    Args:
        model: Key looked up via ``models.get`` to obtain the model class.
        wrapper: Key used to index the object returned by
            ``get_models_dictionary(hyperparams, reweighting)``; the selected
            entry is called with the compiled model.
        dataset: Object whose ``shape`` and ``output_size`` attributes are
            passed positionally to the model class constructor.
        hyperparams: Mapping of hyperparameters. Keys read here:
            ``"opt"`` (``"sgd"`` or ``"adam"``, default ``"adam"``),
            ``"lr"`` (default ``0.001``) and ``"momentum"`` (SGD only,
            default ``0.0``). The whole mapping is also forwarded to
            ``get_models_dictionary``.
        reweighting: Forwarded unchanged to ``get_models_dictionary``.

    Returns:
        The result of calling the selected wrapper on the compiled model.

    Raises:
        ValueError: If ``models.get(model)`` returns ``None``.
        KeyError: If ``hyperparams["opt"]`` names an unsupported optimizer.
            An unknown ``wrapper`` presumably also raises ``KeyError``
            (assuming ``get_models_dictionary`` returns a dict) — confirm.
    """

    def build_optimizer(opt, hyperparams):
        """Construct and return only the optimizer named by ``opt``.

        Raises:
            KeyError: If ``opt`` is neither ``"sgd"`` nor ``"adam"``.
        """
        lr = hyperparams.get("lr", 0.001)
        # Factories are used so only the requested optimizer is constructed;
        # building every optimizer eagerly and discarding all but one is
        # wasteful and may allocate unused optimizer state.
        # NOTE(review): ``lr`` is the legacy Keras keyword (newer Keras
        # versions expect ``learning_rate``) — confirm against the installed
        # Keras before changing it.
        factories = {
            "sgd": lambda: SGD(
                lr=lr,
                momentum=hyperparams.get("momentum", 0.0),
            ),
            "adam": lambda: Adam(lr=lr),
        }
        try:
            factory = factories[opt]
        except KeyError:
            # Keep KeyError for backward compatibility, but say what is valid.
            raise KeyError(
                "unknown optimizer {!r}; expected one of {}".format(
                    opt, sorted(factories)
                )
            ) from None
        return factory()

    model_class = models.get(model)
    if model_class is None:
        # Without this check the failure is an opaque
        # "'NoneType' object is not callable".
        raise ValueError("unknown model {!r}".format(model))
    net = model_class(dataset.shape, dataset.output_size)
    # Loss and metrics are taken from attributes of the model instance
    # itself (presumably defined by each model class — verify).
    net.compile(
        optimizer=build_optimizer(hyperparams.get("opt", "adam"), hyperparams),
        loss=net.loss,
        metrics=net.metrics,
    )
    return get_models_dictionary(hyperparams, reweighting)[wrapper](net)
评论列表
文章目录