def forward(self, x, typ):
    """Run the VGG-style conv stack and return selected intermediate activations.

    Args:
        x: Input fed to ``self.conv1_1``. Presumably an image batch of shape
            (N, C, H, W) -- TODO confirm against callers.
        typ: A ``ForwardType`` member choosing which activations to keep:
            ``ForwardType.Content`` keeps only ``'conv2_2'``;
            ``ForwardType.Style`` and ``ForwardType.Train`` keep
            ``'conv1_2'``, ``'conv2_2'``, ``'conv3_3'`` and ``'conv4_3'``.

    Returns:
        dict mapping each kept layer name to its post-ReLU activation.

    Raises:
        ValueError: if ``typ`` is not Content, Style or Train.
    """
    if typ == ForwardType.Content:
        is_style, is_content = False, True
    elif typ == ForwardType.Style:
        is_style, is_content = True, False
    elif typ == ForwardType.Train:
        is_style, is_content = True, True
    else:
        # ValueError subclasses Exception, so existing ``except Exception``
        # handlers in callers keep working.
        raise ValueError('Unknown forward type, {}'.format(typ))

    internals = {}

    # Stage 1: conv1_1 -> conv1_2, then 2x2 / stride-2 max-pool.
    x = F.relu(self.conv1_1(x))
    x = F.relu(self.conv1_2(x))
    if is_style:
        internals['conv1_2'] = x
    x = F.max_pool2d(x, 2, stride=2)

    # Stage 2: conv2_2 feeds both the style and the content outputs.
    x = F.relu(self.conv2_1(x))
    x = F.relu(self.conv2_2(x))
    if is_style or is_content:
        internals['conv2_2'] = x
    if not is_style:
        # Content-only pass: nothing deeper is returned, so skip the
        # conv3_*/conv4_* work instead of computing and discarding it.
        return internals
    x = F.max_pool2d(x, 2, stride=2)

    # Stage 3 (only style/train passes reach this point).
    x = F.relu(self.conv3_1(x))
    x = F.relu(self.conv3_2(x))
    x = F.relu(self.conv3_3(x))
    internals['conv3_3'] = x
    x = F.max_pool2d(x, 2, stride=2)

    # Stage 4: no pooling after conv4_3 because nothing consumes it.
    x = F.relu(self.conv4_1(x))
    x = F.relu(self.conv4_2(x))
    x = F.relu(self.conv4_3(x))
    internals['conv4_3'] = x
    return internals
# Scraped web-page navigation residue, not code: "评论列表" (comment list), "文章目录" (table of contents).