def loss_style(self, gen_img_rep, layer_weights=None):
    """Return the style loss of the generated image against the style image.

    The generated image's per-layer feature correlations are obtained from
    ``self.feature_cor(gen_img_rep)``. Each layer ``l`` is compared with the
    precomputed style correlations ``self.style_feat_cor[l]``::

        E_l = mean((A_l - G_l) ** 2) / (4 * M_l ** 2)

    ``M_l`` is the product of axes 2 and 3 of ``self.style_rep[l].shape``.

    NOTE(review): this presumably assumes two things (TODO confirm):
      - the representations are laid out as (batch, channels, H, W), so
        ``M_l`` is the spatial size;
      - the correlations are (N_l, N_l) Gram matrices.
    Under those assumptions, the element-wise mean taken by
    ``F.mean_squared_error`` supplies a ``1 / N_l**2`` factor. Each term then
    equals ``sum((A_l - G_l) ** 2) / (4 * N_l**2 * M_l**2)``.

    Args:
        gen_img_rep: per-layer feature representations of the generated
            image. Passed unchanged to ``self.feature_cor``.
        layer_weights: optional per-layer weights ``w_l``; each ``E_l`` is
            multiplied by one of them. Must be indexable for every layer
            returned by ``self.feature_cor``. ``None`` (the default) leaves
            every layer unweighted, exactly as before this parameter existed.

    Returns:
        The sum over layers of the (optionally weighted) ``E_l``, or ``0``
        if ``self.feature_cor`` returns no layers.
    """
    feat_cor_gen = self.feature_cor(gen_img_rep)
    feat_loss = 0
    # The stored style lists are indexed by position (not zipped).
    # A style list shorter than ``feat_cor_gen`` therefore still fails
    # loudly with IndexError instead of being silently truncated.
    for i, gen_cor in enumerate(feat_cor_gen):
        orig_shape = self.style_rep[i].shape
        feat_map_size = orig_shape[2] * orig_shape[3]  # M_l
        layer_wt = 4.0 * feat_map_size ** 2.0
        layer_loss = F.mean_squared_error(
            self.style_feat_cor[i], gen_cor) / layer_wt
        # Apply w_l only when requested, so the default path performs
        # exactly the same arithmetic as the unweighted original.
        if layer_weights is not None:
            layer_loss = layer_weights[i] * layer_loss
        feat_loss += layer_loss
    return feat_loss
# total loss function
# cf. equation (7) of the article
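# NOTE(review): the original lines here were scraped web-page navigation
# text ("comment list", "table of contents"), not part of the program.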