def dap_deploy(m, x, labels, data, att_crit=None):
    """
    Deploy DAP: score every label for each example by multiplying
    per-attribute probabilities, then normalizing over labels.

    :param m: Model; ``m(x)`` must return an object with ``embed_pred`` and
        ``att_pred`` attributes (either may be None, but not both).
    :param x: Input batch, passed straight to ``m``.
    :param labels: Not used by this function; kept for interface
        compatibility.
    :param data: Object exposing ``data.attributes.embeds`` (read when
        ``embed_pred`` is not None; presumably [num_labels x embed_dim] --
        TODO confirm) and ``data.attributes.atts_matrix`` (read when
        ``att_pred`` is not None; one column per attribute, one row per
        label).
    :param att_crit: Object exposing ``domains_per_att``, the domain size of
        each attribute column. Required when the model returns ``att_pred``.
    :return: torch tensor of shape [batch_size x num_labels]; each row is
        normalized to sum to 1.
    :raises ValueError: if the model returns neither ``embed_pred`` nor
        ``att_pred``, or returns ``att_pred`` while ``att_crit`` is None.
    """
    res = m(x)

    # One [batch_size x num_labels] probability matrix per factor.
    att_probs = []

    # Start off with the embedding probabilities: sigmoid of the score of
    # each batch embedding against each label embedding.
    if res.embed_pred is not None:
        embed_logits = res.embed_pred @ data.attributes.embeds.t()
        att_probs.append(torch.sigmoid(embed_logits))

    if res.att_pred is not None:
        if att_crit is None:
            raise ValueError(
                "att_crit is required when the model returns att_pred")
        start_col = 0
        for gt_col, d_size in enumerate(att_crit.domains_per_att):
            # Value of this attribute for every label: [num_labels]
            atts_by_verb = data.attributes.atts_matrix[:, gt_col]
            if d_size == 1:
                # Binary attribute. 2 * (v - 0.5) maps 0/1 to -1/+1
                # (assumes binary values are stored as 0/1). The outer
                # product of the [batch_size] logits with these
                # [num_labels] signs gives a [batch_size x num_labels]
                # matrix of logits signed toward each label's value.
                signs = 2 * (atts_by_verb.float() - 0.5)
                raw_ap = torch.ger(res.att_pred[:, start_col], signs)
                att_probs.append(torch.sigmoid(raw_ap))
            else:
                # Categorical attribute: [batch_size x d_size] distribution.
                # dim=1 is explicit (the implicit-dim form is deprecated).
                ap = F.softmax(
                    res.att_pred[:, start_col:start_col + d_size], dim=1)
                # For each label, pick the probability of its value:
                # [batch_size x num_labels]
                att_probs.append(torch.index_select(ap, 1, atts_by_verb))
            start_col += d_size

    if not att_probs:
        raise ValueError("model returned neither embed_pred nor att_pred")

    # [batch_size x num_labels x num_factors]
    probs_by_att = torch.stack(att_probs, 2)
    # Product over factors; the small epsilon guards against exact zeros.
    # prod() drops dim 2 on modern PyTorch -> [batch_size x num_labels].
    probs_prod = torch.prod(probs_by_att + 1e-12, 2)
    # keepdim=True gives [batch_size x 1], which broadcasts across labels.
    denom = probs_prod.sum(1, keepdim=True)
    return probs_prod / denom
###
评论列表
文章目录