def __init__(self, model, decay_steps=500000, decay_rate=0.5):
    '''
    Set up directories, dataset, model, and optimizer.

    Paths and hyperparameters are read from the module-level FLAGS object.

    Args:
        model: callable invoked as model(hr_images, lr_images) with the
            dataset's image tensors; the returned object must expose a
            `loss` attribute, which the optimizer minimizes.
        decay_steps: number of global steps between learning-rate decays
            (default 500000, the value that was previously hard-coded).
        decay_rate: factor the learning rate is multiplied by once every
            `decay_steps` steps (default 0.5, previously hard-coded).

    Raises:
        OSError: if one of the configured directories cannot be created,
            e.g. because a non-directory file already occupies that path.
    '''
    self.batch_size = FLAGS.batch_size
    self.iterations = FLAGS.iterations
    self.learning_rate = FLAGS.learning_rate  # initial (undecayed) learning rate
    self.model_dir = FLAGS.model_dir  # directory to write model summaries to
    self.dataset_dir = FLAGS.dataset_dir  # directory containing data
    self.samples_dir = FLAGS.samples_dir  # directory for sampled images
    self.device_id = FLAGS.device_id
    self.use_gpu = FLAGS.use_gpu

    # Create the directories if they don't exist yet. Calling makedirs
    # directly and tolerating only "already a directory" avoids the race in
    # a separate exists() check. It also fails fast if the path is a file.
    for path in (self.model_dir, self.dataset_dir, self.samples_dir):
        try:
            os.makedirs(path)
        except OSError:
            if not os.path.isdir(path):
                raise

    device_str = '/gpu:' + str(self.device_id) if self.use_gpu else '/cpu:0'
    with tf.device(device_str):
        # NOTE(review): no dtype is passed, so this uses get_variable's default
        # dtype (float32 in TF 1.x). An integer step counter is conventional,
        # but changing it here would break restoring existing checkpoints.
        self.global_step = tf.get_variable(
            "global_step", [],
            initializer=tf.constant_initializer(0), trainable=False)

        # Parse data and create the model.
        self.dataset = Dataset(self.dataset_dir, self.iterations, self.batch_size)
        # presumably high-/low-resolution image batches -- TODO confirm
        self.model = model(self.dataset.hr_images, self.dataset.lr_images)

        # Staircase schedule: the rate is multiplied by decay_rate once every
        # decay_steps global steps, rather than decaying continuously.
        decayed_lr = tf.train.exponential_decay(
            self.learning_rate, self.global_step,
            decay_steps, decay_rate, staircase=True)
        optimizer = tf.train.RMSPropOptimizer(
            decayed_lr, decay=0.95, momentum=0.9, epsilon=1e-8)
        # Passing global_step makes each run of this op increment the counter.
        self.train_optimizer = optimizer.minimize(
            self.model.loss, global_step=self.global_step)
trainer.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录