train.py source code

python

Project: DeepWorks  Author: daigo0927
# Excerpt of a class method. It needs `import tensorflow as tf` (TensorFlow 1.x API)
# and the project's own U_Net class to be in scope.
def _build_graph(self, image_size):

        self.image_size = image_size
        self.images = tf.placeholder(tf.float32,
                                     shape = (None, image_size, image_size, 3))
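        # Build the blurred network input: shrink to 1/4 resolution, then resize
        # back to the original size so that fine detail is lost.
        images_mini = tf.image.resize_images(self.images,
                                             size = (image_size // 4,
                                                     image_size // 4))
        self.images_blur = tf.image.resize_images(images_mini,
                                                  size = (image_size, image_size))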

        self.net = U_Net(output_ch = 3, block_fn = 'origin')
        self.images_reconst = self.net(self.images_blur, reuse = False)
        # self.images_reconst is unbounded (any real value), so clip it before
        # visualizing it as an image.
        # Loss: mean squared error between the reconstruction and the sharp originals.
        self.loss = tf.reduce_mean((self.images_reconst - self.images)**2)
        self.opt = tf.train.AdamOptimizer()\
                           .minimize(self.loss, var_list = self.net.vars)

        self.saver = tf.train.Saver()
        self.sess.run(tf.global_variables_initializer())
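
The comment in the code notes that the reconstruction is unbounded and has to be clipped before it is shown as an image. Below is a minimal NumPy sketch of that step, together with the same mean-squared-error quantity that `self.loss` computes. It is not part of the original project: the helper names `to_displayable` and `mse` are hypothetical, and the assumption that training images are scaled to [0, 1] is ours, since the excerpt does not state the pixel range.

```python
import numpy as np


def to_displayable(reconst, low=0.0, high=1.0):
    """Clip an unbounded reconstruction into [low, high] and convert to uint8.

    Assumes the images were scaled to [low, high]; that range is an
    assumption of this sketch, not something the original code states.
    """
    clipped = np.clip(reconst, low, high)
    # Map [low, high] linearly onto [0, 255] for display.
    scaled = (clipped - low) / (high - low) * 255.0
    return np.round(scaled).astype(np.uint8)


def mse(reconst, target):
    """Mean squared error over all elements, the same quantity as self.loss."""
    diff = np.asarray(reconst, dtype=np.float64) - np.asarray(target, dtype=np.float64)
    return float(np.mean(diff ** 2))


# Hypothetical 1x2x2x3 batch containing values outside the displayable range.
fake_reconst = np.array([-0.5, 0.0, 0.5, 1.5] * 3, dtype=np.float32).reshape(1, 2, 2, 3)
display = to_displayable(fake_reconst)
```

In a training loop this would be applied to the array returned by evaluating `self.images_reconst` in the session, just before saving or plotting it.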