def __init__(self, embed_ndim):
    """Build the frozen VGG16 image branch and the GRU text branch.

    Args:
        embed_ndim: Width of the shared embedding space. It is the output
            size of the image projection ``W`` and both the input and hidden
            size of the text GRU.
    """
    super(VisualSemanticEmbedding, self).__init__()
    self.embed_ndim = embed_ndim

    # --- image branch ---
    # Pretrained VGG16 backbone. All of its parameters are frozen, so no
    # gradients are computed for them.
    self.img_encoder = models.vgg16(pretrained=True)
    for weight in self.img_encoder.parameters():
        weight.requires_grad = False

    # Reuse (not copy) modules 0..5 of the VGG classifier head. Because they
    # are the same module objects, they share the frozen parameters above.
    # Any classifier modules after index 5 are not used here. The 4096 input
    # width of ``W`` below suggests these six end at a 4096-d feature
    # (TODO confirm against torchvision's VGG layout).
    head = self.img_encoder.classifier
    self.feat_extractor = nn.Sequential(*[head[idx] for idx in range(6)])

    # Trainable, bias-free projection of the 4096-d image feature into the
    # embedding space.
    self.W = nn.Linear(4096, embed_ndim, bias=False)

    # --- text branch ---
    # Single-layer GRU whose input and hidden sizes are both embed_ndim.
    # Presumably its inputs are token vectors already of size embed_ndim
    # (TODO confirm against forward()).
    self.txt_encoder = nn.GRU(embed_ndim, embed_ndim, num_layers=1)
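# NOTE: scraped web-page residue, kept for traceability:
# "评论列表" ("comment list"), "文章目录" ("table of contents").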