def __init__(self, zeroshot, embed_dim=None, att_domains=None, num_train_classes=None, l2_weight=None):
    """
    Build the backbone from ``get_pretrained_resnet`` (stored as
    ``self.resnet152``) plus optional linear heads on top of it.

    :param zeroshot: Whether we're running in zeroshot mode (True or False).
        In zeroshot mode at least one of ``embed_dim`` / ``att_domains`` must
        be supplied and ``num_train_classes`` is ignored. Otherwise
        ``num_train_classes`` is required and ``embed_dim`` / ``att_domains``
        are ignored.
    :param embed_dim: Dimension of embeddings (probably 300). When set in
        zeroshot mode, an ``embed_linear`` head mapping ENCODING_SIZE ->
        embed_dim is created.
    :param att_domains: List of domain sizes per attribute (zeroshot mode
        only). Stored as ``self.att_domains``; ``None`` becomes ``[]``.
    :param num_train_classes: If we're doing pretraining, number of classes to
        use. Passed to ``get_pretrained_resnet`` as the output size.
    :param l2_weight: Stored unchanged as ``self.l2_weight``; not used here.
    :raises ValueError: In zeroshot mode, when neither ``embed_dim`` nor
        ``att_domains`` is supplied. In non-zeroshot mode, when
        ``num_train_classes`` is missing.
    """
    super(ImsituModel, self).__init__()
    self.l2_weight = l2_weight

    if zeroshot:
        # Report which zero-shot signal(s) will be used; at least one is required.
        if (embed_dim is not None) and (att_domains is not None):
            print("Using embeddings and attributes for zeroshot")
        elif embed_dim is not None:
            print("Using embeddings for zeroshot")
        elif att_domains is not None:
            print("using attributes for zeroshot")
        else:
            raise ValueError("Must supply embeddings or attributes for zeroshot")
        # fc_dim=None is handed to get_pretrained_resnet; presumably this means
        # "no classification layer" -- confirm against that helper.
        self.fc_dim = None
        self.att_domains = att_domains if att_domains is not None else []
        self.embed_dim = embed_dim
    else:
        if num_train_classes is None:
            raise ValueError("Must supply a # of training classes")
        # Pretraining mode: only the backbone's output size is configured;
        # the zero-shot heads below are skipped (att_domains/embed_dim cleared).
        self.fc_dim = num_train_classes
        self.att_domains = []
        self.embed_dim = None

    self.resnet152 = get_pretrained_resnet(self.fc_dim)

    if self.embed_dim is not None:
        self.embed_linear = nn.Linear(ENCODING_SIZE, self.embed_dim)
        _init_fc(self.embed_linear)

    # NOTE(review): self.att_dim is never assigned in this method. It is
    # presumably a property defined elsewhere on the class and derived from
    # self.att_domains (e.g. None when empty, else the total size). Confirm
    # this; if no such property exists, this line fails with AttributeError.
    if self.att_dim is not None:
        self.att_linear = nn.Linear(ENCODING_SIZE, self.att_dim)
        _init_fc(self.att_linear)
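# (web-page residue, not code) 评论列表 -- "Comment list"
# (web-page residue, not code) 文章目录 -- "Table of contents"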