def generate_model_file(num_rows, num_cols, num_cats=4, rate=1.0):
    """Train a model on a generated dataset, reusing a cached file if any.

    The output file lives under ``DATA`` and its name is derived from the
    four arguments, so a file already present at that path is returned
    as-is without regenerating the dataset or retraining.

    Args:
        num_rows: Forwarded to ``generate_dataset_file``; part of the
            file name.
        num_cols: Forwarded to ``generate_dataset_file``; part of the
            file name. Also fixes the length of the tree prior at
            ``num_cols * (num_cols - 1) // 2``.
        num_cats: Forwarded to ``generate_dataset_file``; part of the
            file name.
        rate: Forwarded to ``generate_dataset_file``; appears in the file
            name formatted with one decimal place.

    Returns:
        The path to a gzipped pickled model.
    """
    filename = '{}-{}-{}-{:0.1f}.model.pkz'.format(
        num_rows, num_cols, num_cats, rate)
    path = os.path.join(DATA, filename)
    # One prior entry per unordered pair of columns.
    num_pairs = num_cols * (num_cols - 1) // 2

    if not os.path.exists(path):
        print('Generating {}'.format(path))
        if not os.path.exists(DATA):
            os.makedirs(DATA)
        table = pickle_load(
            generate_dataset_file(num_rows, num_cols, num_cats, rate))['table']
        model = train_model(
            table,
            np.zeros(num_pairs, dtype=np.float32),  # all-zero tree prior
            make_config(learning_init_epochs=5),
        )
        pickle_dump(model, path)
    return path
评论列表
文章目录