def main(_):
    """Run training or sample generation, as selected by ``conf.is_train``.

    Builds the data and sample directory paths under
    ``conf.runtime_base_dir``, passes both to ``util.check_and_create_dir``,
    loads the dataset, and then opens a TensorFlow session in which the
    network and the statistics helper are constructed. After
    ``stat.load_model()`` has been called, control is handed to ``train``
    or ``generate``.

    Args:
        _: Unused positional argument (presumably the argv list supplied by
            the TensorFlow app runner -- confirm against the call site).
    """
    # Keys handed to util.get_model_dir alongside conf; presumably the
    # settings that should not influence the model directory name
    # (TODO confirm against util.get_model_dir).
    conf_keys = [
        'data_dir', 'sample_dir', 'max_epoch', 'test_step', 'save_step',
        'is_train', 'random_seed', 'log_level', 'display', 'runtime_base_dir',
        'occlude_start_row', 'num_generated_images',
    ]
    # The model directory is derived before conf is preprocessed/validated;
    # keep this ordering, since preprocess_conf may alter conf.
    model_dir = util.get_model_dir(conf, conf_keys)
    util.preprocess_conf(conf)
    validate_parameters(conf)

    # 'color-mnist' reads its data from the plain 'mnist' directory, while
    # its samples still go under the 'color-mnist' name (conf.data).
    if conf.data == 'color-mnist':
        dataset_name = 'mnist'
    else:
        dataset_name = conf.data

    data_path = os.path.join(conf.runtime_base_dir, conf.data_dir, dataset_name)
    sample_path = os.path.join(
        conf.runtime_base_dir, conf.sample_dir, conf.data, model_dir)

    for directory in (data_path, sample_path):
        util.check_and_create_dir(directory)

    dataset = get_dataset(data_path, conf.q_levels)

    with tf.Session() as session:
        net = Network(
            session, conf, dataset.height, dataset.width, dataset.channels)

        # Statistic receives tf.trainable_variables() after the Network has
        # been constructed, so the call must stay below the line above.
        stats = Statistic(
            session, conf.data, conf.runtime_base_dir, model_dir,
            tf.trainable_variables())
        stats.load_model()

        if conf.is_train:
            train(dataset, net, stats, sample_path)
        else:
            generate(net, dataset.height, dataset.width, sample_path)
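# Page-scrape residue, not part of the program: blog navigation labels
# ("Comment list" / "Article contents").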