def _load_attr_from_file(module_name, path, attr):
    """Execute the Python source at *path* as module *module_name*; return its *attr*.

    Replacement for ``getattr(imp.load_source(module_name, path), attr)``.
    The ``imp`` module was removed in Python 3.12.

    Raises:
        AttributeError: if the executed module does not define *attr*.
    """
    import importlib.machinery
    import importlib.util
    import sys

    # An explicit SourceFileLoader accepts any file extension, as
    # imp.load_source did (spec_from_file_location alone would infer the
    # loader from the suffix).
    loader = importlib.machinery.SourceFileLoader(module_name, path)
    spec = importlib.util.spec_from_file_location(module_name, path, loader=loader)
    module = importlib.util.module_from_spec(spec)
    # imp.load_source registered the module in sys.modules. Keep doing so,
    # before execution, so that code inside the module can look itself up.
    sys.modules[module_name] = module
    try:
        loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind.
        sys.modules.pop(module_name, None)
        raise
    return getattr(module, attr)


def _copy_if_absent(src, dst_dir):
    """Copy file *src* into *dst_dir* unless a file with its basename already exists."""
    dst = os.path.join(dst_dir, os.path.basename(src))
    if not os.path.exists(dst):
        shutil.copy(src, dst)


def get_model(
        model_file, model_name, loss_file, loss_name, class_weight, n_encdec,
        n_classes, in_channel, n_mid, train_depth=None, result_dir=None):
    """Build a model from a class defined in a source file, optionally wrapped in a loss.

    Args:
        model_file: Path to a Python source file defining class ``model_name``.
        model_name: Name of the model class. Also used as the module name.
        loss_file: Path to a Python source file defining class ``loss_name``.
            It is always loaded, even when no wrapping happens.
        loss_name: Name of the loss class. Also used as the module name.
        class_weight: Passed through to the loss class constructor.
        n_encdec, n_classes, in_channel, n_mid: Positional constructor
            arguments for the model class, in that order.
        train_depth: If truthy, the model is wrapped as
            ``loss_cls(model, class_weight, train_depth)``.
        result_dir: If not None, ``model_file`` and ``loss_file`` are copied
            into this (existing) directory. Files already present there are
            not overwritten.

    Returns:
        The model instance, or the loss instance wrapping it when
        ``train_depth`` is truthy.
    """
    model_cls = _load_attr_from_file(model_name, model_file, model_name)
    loss_cls = _load_attr_from_file(loss_name, loss_file, loss_name)

    model = model_cls(n_encdec, n_classes, in_channel, n_mid)
    # Truthiness kept from the original: None (and 0) leave the model unwrapped.
    if train_depth:
        model = loss_cls(model, class_weight, train_depth)

    # Copy the model and loss sources next to the results.
    if result_dir is not None:
        _copy_if_absent(model_file, result_dir)
        _copy_if_absent(loss_file, result_dir)
    return model
评论列表
文章目录