trainer.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:Pixel-Recursive-Super-Resolution 作者: hodgka 项目源码 文件源码
def __init__(self, model):
        '''
        Build the training graph from the command-line FLAGS.

        Copies the run configuration from FLAGS onto the instance, makes sure
        the model/dataset/samples directories exist, and then, on the selected
        device, creates the global step, the Dataset, the model, and the
        training op.

        Args:
            model: callable invoked as model(hr_images, lr_images) with the
                tensors exposed by the Dataset; the object it returns must
                provide a `loss` attribute, which is what gets minimized.
        '''
        # hyper-parameters
        self.batch_size = FLAGS.batch_size
        self.iterations = FLAGS.iterations
        self.learning_rate = FLAGS.learning_rate

        # filesystem locations
        self.model_dir = FLAGS.model_dir  # where model summaries are written
        self.dataset_dir = FLAGS.dataset_dir  # where the data lives
        self.samples_dir = FLAGS.samples_dir  # where sampled images go

        # device selection
        self.device_id = FLAGS.device_id
        self.use_gpu = FLAGS.use_gpu

        # create any of the directories that don't exist yet
        for directory in (self.model_dir, self.dataset_dir, self.samples_dir):
            if not os.path.exists(directory):
                os.makedirs(directory)

        device_str = '/gpu:' + str(self.device_id) if self.use_gpu else '/cpu:0'

        with tf.device(device_str):
            # non-trainable scalar counter, incremented by the training op below
            self.global_step = tf.get_variable(
                "global_step",
                [],
                initializer=tf.constant_initializer(0),
                trainable=False,
            )

            # input pipeline first, then the model that consumes its tensors
            self.dataset = Dataset(self.dataset_dir, self.iterations, self.batch_size)
            self.model = model(self.dataset.hr_images, self.dataset.lr_images)

            # step-wise schedule: the rate is halved every 500000 global steps
            decayed_lr = tf.train.exponential_decay(
                self.learning_rate,
                self.global_step,
                decay_steps=500000,
                decay_rate=0.5,
                staircase=True,
            )
            rmsprop = tf.train.RMSPropOptimizer(
                decayed_lr,
                decay=0.95,
                momentum=0.9,
                epsilon=1e-8,
            )
            self.train_optimizer = rmsprop.minimize(
                self.model.loss,
                global_step=self.global_step,
            )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号