def main():
    """Train MODEL on the dataset described by ``args.architecture``.

    NOTE: The input is rescaled to [-1, 1] (via the Tanhize normalizer).

    Relies on module-level ``args`` (parsed CLI flags) and the helpers
    ``validate_log_dirs``, ``Tanhize``, ``read``, ``MODEL``, ``TRAINER``
    defined elsewhere in this file.
    """
    dirs = validate_log_dirs(args)
    tf.gfile.MakeDirs(dirs['logdir'])

    # Load the architecture spec, then archive a copy into the log
    # directory so the run configuration is reproducible.
    with open(args.architecture) as f:
        arch = json.load(f)
    # Join only the basename: if args.architecture carries a directory
    # component (e.g. 'configs/arch.json'), joining the full path into
    # logdir would point at a non-existent subdirectory and the write
    # below would fail.
    arch_copy = os.path.join(dirs['logdir'], os.path.basename(args.architecture))
    with open(arch_copy, 'w') as f:
        json.dump(arch, f, indent=4)

    # Feature normalizer built from precomputed per-dimension extrema.
    # NOTE(review): np.fromfile defaults to float64 — confirm the .npf
    # files were written with that dtype.
    normalizer = Tanhize(
        xmax=np.fromfile('./etc/xmax.npf'),
        xmin=np.fromfile('./etc/xmin.npf'),
    )

    # Queue-based input pipeline (TF1-style); yields normalized batches.
    image, label = read(
        file_pattern=arch['training']['datadir'],
        batch_size=arch['training']['batch_size'],
        capacity=2048,
        min_after_dequeue=1024,
        normalizer=normalizer,
    )

    machine = MODEL(arch)
    loss = machine.loss(image, label)
    trainer = TRAINER(loss, arch, args, dirs)
    trainer.train(nIter=arch['training']['max_iter'], machine=machine)
# (removed scraped web-page residue: "评论列表" [comment list] / "文章目录"
#  [table of contents] — these bare identifiers would raise NameError on import)