model.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:dong_iccv_2017 作者: woozzu 项目源码 文件源码
def __init__(self, embed_ndim):
        super(VisualSemanticEmbedding, self).__init__()
        self.embed_ndim = embed_ndim

        # image feature
        self.img_encoder = models.vgg16(pretrained=True)
        for param in self.img_encoder.parameters():
            param.requires_grad = False
        self.feat_extractor = nn.Sequential(*(self.img_encoder.classifier[i] for i in range(6)))
        self.W = nn.Linear(4096, embed_ndim, False)

        # text feature
        self.txt_encoder = nn.GRU(embed_ndim, embed_ndim, 1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号