vgg16.py 文件源码

python
阅读 40 收藏 0 点赞 0 评论 0

项目:neural-style 作者: ctliu3 项目源码 文件源码
def forward(self, x, typ):
        """Run ``x`` through the conv stack and collect selected activations.

        Args:
            x: Input fed to ``self.conv1_1``. Presumably an image batch of
                shape (N, 3, H, W) -- TODO confirm against callers.
            typ: A ``ForwardType`` member selecting which activations to keep.
                ``ForwardType.Content`` keeps only ``'conv2_2'``.
                ``ForwardType.Style`` and ``ForwardType.Train`` both keep
                ``'conv1_2'``, ``'conv2_2'``, ``'conv3_3'`` and ``'conv4_3'``.

        Returns:
            dict: Maps each kept layer name to that layer's activation *after*
            ``F.relu`` has been applied.

        Raises:
            ValueError: If ``typ`` is not Content, Style or Train.
        """
        if typ == ForwardType.Content:
            is_style, is_content = False, True
        elif typ == ForwardType.Style:
            is_style, is_content = True, False
        elif typ == ForwardType.Train:
            is_style, is_content = True, True
        else:
            # ValueError subclasses Exception, so callers that caught the
            # previous generic Exception keep working.
            raise ValueError('Unknown forward type, {}'.format(typ))

        internals = {}

        # Block 1: two convs, then a 2x2 max-pool.
        x = F.relu(self.conv1_1(x))
        x = F.relu(self.conv1_2(x))
        if is_style:
            internals['conv1_2'] = x
        x = F.max_pool2d(x, 2, stride=2)

        # Block 2: the only layer shared by the content and style passes.
        x = F.relu(self.conv2_1(x))
        x = F.relu(self.conv2_2(x))
        if is_style or is_content:
            internals['conv2_2'] = x
        if not is_style:
            # Content-only pass: every deeper activation is style-only, so
            # stop here instead of computing blocks 3-4 and discarding them.
            return internals
        x = F.max_pool2d(x, 2, stride=2)

        # Block 3: three convs, then a 2x2 max-pool.
        x = F.relu(self.conv3_1(x))
        x = F.relu(self.conv3_2(x))
        x = F.relu(self.conv3_3(x))
        if is_style:
            internals['conv3_3'] = x
        x = F.max_pool2d(x, 2, stride=2)

        # Block 4: three convs; no trailing pool because nothing deeper is used.
        x = F.relu(self.conv4_1(x))
        x = F.relu(self.conv4_2(x))
        x = F.relu(self.conv4_3(x))
        if is_style:
            internals['conv4_3'] = x

        return internals
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号