def __fit(self, content_image, style_image, epoch_num, callback=None):
    """Optimize an image toward the content of one image and the style of another.

    Optimization runs coarse-to-fine over the subsampling strides
    ``[4, 2, 1][-self.resolution_num:]``. Each level is initialised from the
    previous level's result, upsampled 2x by pixel repetition. Each level then
    runs ``epoch_num`` iterations of ``self.__fit_one``.

    Coarse levels whose subsampled width is below 64 px are skipped. The
    finest level is always run, so small images still produce a result.

    Args:
        content_image: Content image array, subsampled as ``[:, :, ::s, ::s]``
            (presumably (batch, channel, height, width) -- confirm with callers).
        style_image: Style image array, subsampled the same way.
        epoch_num: Number of ``self.__fit_one`` iterations per resolution level.
        callback: Optional ``callback(epoch, x, loss_info)`` invoked after each
            iteration; ``epoch`` counts continuously across all levels.

    Returns:
        ``link.x`` of the last (finest) level that was optimized.
    """
    xp = self.xp
    width = content_image.shape[-1]
    strides = [4, 2, 1][-self.resolution_num:]
    input_image = None
    base_epoch = 0
    for level, stride in enumerate(strides):
        is_finest = level == len(strides) - 1
        # Skip too-small coarse levels, but never the finest one: previously a
        # narrow image skipped every level and then failed with an
        # UnboundLocalError on ``link`` at the return statement.
        if width // stride < 64 and not is_finest:
            continue

        content_x = xp.asarray(content_image[:, :, ::stride, ::stride])
        style_x = xp.asarray(style_image[:, :, ::stride, ::stride])
        if self.keep_color:
            # Presumably restricts the style target to luminance so the
            # content's colours are kept -- see util.luminance_only.
            style_x = util.luminance_only(style_x, content_x)

        # The target features are fixed, so no computational graph is needed.
        with chainer.using_config('enable_backprop', False):
            content_features = self.model(content_x)
            style_features = self.model(style_x)
        content_layers = [(name, content_features[name])
                          for name in self.content_layer_names]
        style_grams = [(name, util.gram_matrix(style_features[name]))
                       for name in self.style_layer_names]

        if input_image is None:
            # First level actually processed: start from the content image
            # or from small Gaussian noise.
            if self.initial_image == 'content':
                input_image = content_x
            else:
                input_image = xp.random.normal(
                    0, 1, size=content_x.shape).astype(np.float32) * 0.001
        else:
            # Upsample the previous result 2x, then crop: ``::stride`` rounds
            # sizes up, so the repeated image can be one pixel too large.
            h, w = content_x.shape[-2:]
            input_image = input_image.repeat(2, 2).repeat(2, 3)[:, :, :h, :w]

        link = chainer.Link(x=input_image.shape)
        if self.device_id >= 0:
            link.to_gpu()
        # Copy into the parameter's own buffer (no aliasing with input_image).
        link.x.data[:] = xp.asarray(input_image)
        self.optimizer.setup(link)
        for epoch in six.moves.range(epoch_num):
            loss_info = self.__fit_one(link, content_layers, style_grams)
            if callback:
                callback(base_epoch + epoch, link.x, loss_info)
        base_epoch += epoch_num
        input_image = link.x.data
    return link.x
评论列表
文章目录