neural_style.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:chainer-neural-style 作者: dsanno 项目源码 文件源码
def __fit(self, content_image, style_image, epoch_num, callback=None):
    """Optimize an image coarse-to-fine to match content and style targets.

    The content and style images are subsampled with strides taken from the
    last ``self.resolution_num`` entries of ``[4, 2, 1]``. Any stride whose
    subsampled width is below 64 pixels is skipped. For each processed
    stride:

    - Content-layer activations and style-layer Gram matrices are computed
      once, with backprop disabled.
    - The image stored in a ``chainer.Link`` parameter ``x`` is optimized
      for ``epoch_num`` iterations via ``self.__fit_one``, using
      ``self.optimizer``.

    The result at one stride is upsampled 2x by pixel repetition, then
    cropped to initialize the next, finer stride.

    Args:
        content_image: Content image array, indexed as
            ``[:, :, rows, cols]``. Assumed to be NCHW float32; TODO confirm
            against the caller.
        style_image: Style image array with the same axis layout.
        epoch_num: Number of optimization iterations per resolution.
        callback: Optional ``callback(epoch, x, loss_info)`` invoked after
            every iteration. ``epoch`` counts across all resolutions, ``x``
            is the parameter being optimized, and ``loss_info`` is the
            return value of ``__fit_one``.

    Returns:
        The optimized image parameter (``link.x``) at the finest processed
        resolution.

    Raises:
        ValueError: If every selected resolution was skipped because its
            subsampled width is below 64 pixels, so nothing was optimized.
    """
    xp = self.xp
    strides = [4, 2, 1][-self.resolution_num:]
    width = content_image.shape[-1]
    input_image = None
    link = None
    base_epoch = 0
    for stride in strides:
        # Skip resolutions too small to be worth optimizing.
        # The skip condition is monotone in the stride, so processed
        # strides always form a suffix of `strides` and consecutive
        # processed strides differ by exactly 2x.
        if width // stride < 64:
            continue
        content_x = xp.asarray(content_image[:, :, ::stride, ::stride])
        style_x = xp.asarray(style_image[:, :, ::stride, ::stride])
        if self.keep_color:
            # Style is passed through util.luminance_only together with the
            # content; presumably to keep the content's colors (see util).
            style_x = util.luminance_only(style_x, content_x)

        # Targets are constants for the optimization, so no graph is needed.
        with chainer.using_config('enable_backprop', False):
            content_features = self.model(content_x)
        content_layers = [(name, content_features[name])
                          for name in self.content_layer_names]
        with chainer.using_config('enable_backprop', False):
            style_features = self.model(style_x)
        style_grams = [(name, util.gram_matrix(style_features[name]))
                       for name in self.style_layer_names]

        if input_image is None:
            if self.initial_image == 'content':
                input_image = content_x
            else:
                noise = xp.random.normal(0, 1, size=content_x.shape)
                input_image = noise.astype(np.float32) * 0.001
        else:
            # Upsample the previous (coarser) result 2x by repetition.
            # Crop it, because subsampling with `::stride` rounds sizes up,
            # so the doubled image can be one pixel larger than the target.
            h, w = content_x.shape[-2:]
            upsampled = input_image.repeat(2, axis=2).repeat(2, axis=3)
            input_image = upsampled[:, :, :h, :w]

        link = chainer.Link(x=input_image.shape)
        if self.device_id >= 0:
            link.to_gpu()
        # Copy into the parameter buffer, so the optimized image never
        # aliases content_x or the caller's arrays.
        link.x.data[:] = xp.asarray(input_image)
        self.optimizer.setup(link)
        for epoch in six.moves.range(epoch_num):
            loss_info = self.__fit_one(link, content_layers, style_grams)
            if callback is not None:
                callback(base_epoch + epoch, link.x, loss_info)
        base_epoch += epoch_num
        input_image = link.x.data

    if link is None:
        # Previously this surfaced as an UnboundLocalError on `link`.
        raise ValueError(
            'image width {} is too small: every selected stride {} gives a '
            'subsampled width below 64 pixels'.format(width, strides))
    return link.x
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号