def forward(self, x, lengths):
    """Encode a padded batch of variable-size captions.

    Args:
        x: batch of word ids, batch-first (``pack_padded_sequence`` is called
            with ``batch_first=True``).
        lengths: one length per batch row; passed to ``pack_padded_sequence``
            and to ``torch.LongTensor``.

    Returns:
        The RNN output at the last valid step (``length - 1``) of every row,
        normalised with ``l2norm`` and, when ``self.use_abs`` is set,
        replaced by its absolute value.
    """
    # Embed word ids to vectors
    x = self.embed(x)
    packed = pack_padded_sequence(x, lengths, batch_first=True)
    # Forward propagate RNN
    out, _ = self.rnn(packed)
    # Unpack to a padded (batch, max_len, hidden) tensor
    padded = pad_packed_sequence(out, batch_first=True)[0]
    # Index of the last valid step of each row, broadcast across the RNN's
    # actual output width (rather than assuming it equals self.embed_size).
    last = torch.LongTensor(lengths).view(-1, 1, 1) - 1
    last = Variable(last.expand(padded.size(0), 1, padded.size(2)))
    # Put the index on the RNN output's device instead of an unconditional
    # .cuda(), so the encoder also runs on CPU.
    if padded.is_cuda:
        last = last.cuda(padded.get_device())
    # Reshape *final* output to (batch_size, hidden_size)
    out = torch.gather(padded, 1, last).squeeze(1)
    # normalization in the joint embedding space
    out = l2norm(out)
    # take absolute value, used by order embeddings
    if self.use_abs:
        out = torch.abs(out)
    return out
# Example source code using gather() in Python
def reverse_sequence(self, x, x_lens):
    """Reverse the first ``x_lens[b]`` steps of every row of a batch-first tensor.

    ``x`` is unpacked as ``(batch_size, seq_len, word_dim)``. Two gathers
    along dim 1 do the work: a full reversal, then a per-row shift of
    ``seq_len - x_lens[b]`` that moves the reversed valid prefix to the front.
    Shifted indices past the end are clamped to ``seq_len - 1``, so trailing
    positions repeat the row's original first element (no zero-filling).

    :param x: tensor of shape (batch_size, seq_len, word_dim)
    :param x_lens: per-row lengths; presumably a 1-D LongTensor/Variable of
        size batch_size on x's device — TODO confirm against callers
    :return: tensor with the same shape as x
    """
    _, seq_len, _ = x.size()
    backwards = Variable(torch.arange(seq_len - 1, -1, -1).long())
    positions = Variable(torch.arange(0, seq_len).long())
    if x.is_cuda:
        device = x.get_device()
        backwards = backwards.cuda(device)
        positions = positions.cuda(device)

    def _spread(index):
        # (seq_len,) -> (1, seq_len, 1) -> broadcast to x's full shape
        return index.unsqueeze(0).unsqueeze(-1).expand_as(x)

    offset = (seq_len - x_lens).unsqueeze(-1).unsqueeze(-1).expand_as(x)
    shifted = (_spread(positions) + offset).clamp(0, seq_len - 1)
    return x.gather(1, _spread(backwards)).gather(1, shifted)
def forward(self, logits, target):
    """
    Per-example perplexity of ``target`` under ``logits``.

    :param logits: tensor with shape of [batch_size, seq_len, input_size]
    :param target: tensor with shape of [batch_size, seq_len] of Long type filled with indexes to gather from logits
    :return: tensor with shape of [batch_size] with perplexity evaluation, i.e. the exp of
        the mean negative log-probability of the target indexes over seq_len
    """
    # log_softmax over the input_size axis. The dim is given explicitly: the
    # implicit-dim form is deprecated, and it only chose the right axis
    # because the old code flattened to 2-D first.
    log_probs = F.log_softmax(logits, dim=-1)
    # Negative log-probability of each target index: [batch_size, seq_len].
    # The tensor method avoids depending on a module alias for torch.
    nll = log_probs.gather(2, target.unsqueeze(2)).squeeze(2).neg()
    return nll.mean(1).exp()
def eliminate_rows(self, prob_sc, ind, phis, threshold=0.85):
    """Eliminate rows of phis and of the scaled probability matrix.

    Row r of batch element b is eliminated when ``prob_sc[b, r, 0]`` exceeds
    ``threshold``. Rows are then permuted so that surviving rows come first
    (in their original order) and eliminated rows last.

    :param prob_sc: 3-D scaled probability matrix (indexed ``[:, :, 0]``);
        eliminated rows are replaced by ``[1, 0, ..., 0]`` before permuting
    :param ind: running permutation, composed with the new one along dim 1
    :param phis: 3-D tensor permuted along dims 1 and 2; rows and columns
        belonging to eliminated entries are zeroed
    :param threshold: elimination threshold on column 0 (default 0.85, the
        value previously hard-coded)
    :return: (prob_sc, ind, phis_out, active); ``active`` is 1 for surviving
        rows and 0 for eliminated rows, in the new order
    """
    length = prob_sc.size()[1]
    # 1 where a row is eliminated, 0 otherwise
    mask = (prob_sc[:, :, 0] > threshold).type(dtype)
    # Row positions 0 .. length-1 per batch element. torch.arange(0, length)
    # gives the same values as the deprecated torch.range(0, length - 1).
    rang = (Variable(torch.arange(0, length).unsqueeze(0)
                     .expand_as(mask))
            .type(dtype))
    # Sort key: survivors keep their position, eliminated rows get `length`,
    # so sorting moves eliminated rows to the end.
    ind_sc = torch.sort(rang * (1 - mask) + length * mask, 1)[1]
    # permute prob_sc; eliminated rows become [1, 0, ..., 0]
    m = mask.unsqueeze(2).expand_as(prob_sc)
    mm = m.clone()
    mm[:, :, 1:] = 0
    prob_sc = (torch.gather(prob_sc * (1 - m) + mm, 1,
                            ind_sc.unsqueeze(2).expand_as(prob_sc)))
    # compose permutations
    ind = torch.gather(ind, 1, ind_sc)
    active = torch.gather(1 - mask, 1, ind_sc)
    # permute phis along rows (dim 1) then columns (dim 2), zeroing inactive ones
    active1 = active.unsqueeze(2).expand_as(phis)
    ind1 = ind.unsqueeze(2).expand_as(phis)
    active2 = active.unsqueeze(1).expand_as(phis)
    ind2 = ind.unsqueeze(1).expand_as(phis)
    phis_out = torch.gather(phis, 1, ind1) * active1
    phis_out = torch.gather(phis_out, 2, ind2) * active2
    return prob_sc, ind, phis_out, active
def get_ranking(predictions, labels, num_guesses=5):
    """
    Given a matrix of predictions and labels for the correct ones, get the number of guesses
    required to get the prediction right per example.
    :param predictions: [batch_size, range_size] predictions
    :param labels: [batch_size] array of labels
    :param num_guesses: Number of guesses to return
    :return: (gt_ranks, guesses) where gt_ranks is a [batch_size] tensor with the 0-based
        rank of each label (0 = top prediction) and guesses is [batch_size, num_guesses]
        with the highest-scoring class indices
    """
    assert labels.size(0) == predictions.size(0)
    assert labels.dim() == 1
    assert predictions.dim() == 2
    # full_guesses[b, r] is the class with the r-th highest score in row b
    _, full_guesses = predictions.topk(predictions.size(1), dim=1)
    # Inverting that permutation: ranking[b, c] is the rank of class c in row b
    _, ranking = full_guesses.topk(full_guesses.size(1), dim=1, largest=False)
    # squeeze(1) rather than squeeze(): a batch of one must stay 1-D
    gt_ranks = torch.gather(ranking.data, 1, labels[:, None]).squeeze(1)
    guesses = full_guesses[:, :num_guesses]
    return gt_ranks, guesses
def compute_loss(self, batch, output, target):
    """ See base class for args description.

    With ``self.confidence < 1`` the target distribution is built from
    ``self.one_hot`` with ``self.confidence`` scattered at the gold index. Rows
    whose target equals ``self.padding_idx`` are zeroed both in that
    distribution and in the gold log-likelihood used for the reported loss.

    Returns:
        (loss, stats): the criterion output and ``self.stats(...)`` computed on
        the (masked) loss data.
    """
    scores = self.generator(self.bottle(output))
    gtruth = target.view(-1)
    if self.confidence < 1:
        tdata = gtruth.data
        # Row indices of padding targets, shape (num_pad, 1). Flatten with
        # view(-1) instead of squeeze(): squeeze() turns a single match into
        # a 0-dim tensor, and the old `dim() > 0` check then skipped masking.
        pad_rows = torch.nonzero(tdata.eq(self.padding_idx))
        likelihood = torch.gather(scores.data, 1, tdata.unsqueeze(1))
        tmp_ = self.one_hot.repeat(gtruth.size(0), 1)
        tmp_.scatter_(1, tdata.unsqueeze(1), self.confidence)
        if pad_rows.numel() > 0:
            pad_rows = pad_rows.view(-1)
            likelihood.index_fill_(0, pad_rows, 0)
            tmp_.index_fill_(0, pad_rows, 0)
        gtruth = Variable(tmp_, requires_grad=False)
    loss = self.criterion(scores, gtruth)
    if self.confidence < 1:
        loss_data = - likelihood.sum(0)
    else:
        loss_data = loss.data.clone()
    stats = self.stats(loss_data, scores.data, target.view(-1).data)
    return loss, stats
def prepare_batch(xs, lens, gpu=True):
    """Sort a batch by descending length and wrap it for inference.

    ``xs`` is reordered along dim 1 (the gather index is shaped (1, B, 1)
    before expanding, so xs is 3-D with the batch on dim 1).

    :param xs: 3-D tensor with the batch along dim 1
    :param lens: 1-D tensor of lengths, one per batch entry
    :param gpu: when true, move all three outputs with .cuda()
    :return: (xs, lens, ridx) wrapped as volatile Variables; ridx is the
        permutation that restores the caller's original batch order
    """
    sorted_lens, order = torch.sort(lens, 0, True)
    _, restore = torch.sort(order, 0)
    gather_index = order.unsqueeze(0).unsqueeze(-1).expand_as(xs)
    pieces = (torch.gather(xs, 1, gather_index), sorted_lens, restore)
    wrapped = [Variable(piece, volatile=True) for piece in pieces]
    if gpu:
        wrapped = [piece.cuda() for piece in wrapped]
    out_xs, out_lens, out_ridx = wrapped
    return out_xs, out_lens, out_ridx
def test_gather(self):
    """Check torch.gather against an element-wise reference on random 3-D input.

    Also checks that an out-of-range index raises RuntimeError, and that
    gathering with the indices returned by ``max`` reproduces the max values.
    """
    m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
    elems_per_row = random.randint(1, 10)
    dim = random.randrange(3)
    src = torch.randn(m, n, o)
    idx_size = [m, n, o]
    idx_size[dim] = elems_per_row
    idx = torch.LongTensor().resize_(*idx_size)
    self._fill_indices(idx, dim, src.size(dim), elems_per_row, m, n, o)
    actual = torch.gather(src, dim, idx)
    expected = torch.Tensor().resize_(*idx_size)
    # Reference: copy src at each position with the `dim` coordinate replaced by idx
    for i in range(idx_size[0]):
        for j in range(idx_size[1]):
            for k in range(idx_size[2]):
                ii = [i, j, k]
                ii[dim] = idx[i, j, k]
                expected[i, j, k] = src[tuple(ii)]
    # third argument presumably a tolerance in this test harness — TODO confirm
    self.assertEqual(actual, expected, 0)
    # every dimension of src is at most 20, so 23 is out of range
    idx[0][0][0] = 23
    self.assertRaises(RuntimeError, lambda: torch.gather(src, dim, idx))
    src = torch.randn(3, 4, 5)
    # keepdim=True keeps idx 3-D: gather requires index and src to have the
    # same number of dimensions (plain max(2) returns a 2-D index).
    expected, idx = src.max(2, True)
    actual = torch.gather(src, 2, idx)
    self.assertEqual(actual, expected, 0)
def gather_index(input, index):
    """Select one column per row: ``output[i] = input[i, index[i]]``.

    :param input: 2-D tensor of shape (N, M)
    :param index: 1-D LongTensor of shape (N,) with column indices
    :return: 1-D tensor of shape (N,)
    :raises AssertionError: if the ranks are not 2 and 1
    """
    assert input.dim() == 2 and index.dim() == 1
    # Gather a single column per row. The previous version expanded the index
    # to the full (N, M) width and then discarded all but column 0.
    return torch.gather(input, 1, index.unsqueeze(1)).squeeze(1)
def compute_loss(logits, y, lens):
    """Masked token-level negative log-likelihood for a padded batch.

    :param logits: tensor unpacked as (batch_size, seq_len, vocab_size)
    :param y: target ids, flattened with view(-1) to batch_size * seq_len
    :param lens: per-sequence lengths; passed to ``sequence_mask`` and summed
    :return: (loss_batch, loss_step): the masked loss sum divided by the
        number of sequences, and by the total length
    """
    batch_size, seq_len, vocab_size = logits.size()
    logits = logits.view(batch_size * seq_len, vocab_size)
    y = y.view(-1)
    # Explicit dim: the implicit-dim form of log_softmax is deprecated.
    logprobs = F.log_softmax(logits, dim=1)
    losses = -torch.gather(logprobs, 1, y.unsqueeze(-1))
    losses = losses.view(batch_size, seq_len)
    # presumably 1 for steps within each length and 0 for padding — confirm sequence_mask
    mask = sequence_mask(lens, seq_len).float()
    losses = losses * mask
    total = losses.sum()
    loss_batch = total / len(lens)
    loss_step = total / lens.sum().float()
    return loss_batch, loss_step
def prepare_batch(self, batch_data, volatile=False):
    """Sort sources and targets by descending length and wrap them as Variables.

    :param batch_data: tuple (x, x_lens, ys, ys_lens)
    :param volatile: forwarded to every Variable
    :return: (x, x_lens, ys_i, ys_t, ys_lens, xys_idx) where ys_i/ys_t are
        ``ys`` without its last/first step, ys_lens is reduced by one, and
        xys_idx[..] is the position (in sorted x) of the source paired with
        each sorted target entry
    """
    x, x_lens, ys, ys_lens = batch_data
    batch_dim = 0 if self.batch_first else 1
    context_dim = 1 if self.batch_first else 0
    # Sort sources by descending length; x_ridx is the inverse permutation.
    x_lens, x_idx = torch.sort(x_lens, 0, True)
    _, x_ridx = torch.sort(x_idx)
    # Sort targets by descending length along the batch dimension.
    ys_lens, ys_idx = torch.sort(ys_lens, batch_dim, True)
    x_ridx_exp = x_ridx.unsqueeze(context_dim).expand_as(ys_idx)
    xys_idx = torch.gather(x_ridx_exp, batch_dim, ys_idx)
    x = x[x_idx]
    ys = torch.gather(ys, batch_dim, ys_idx.unsqueeze(-1).expand_as(ys))
    x = Variable(x, volatile=volatile)
    x_lens = Variable(x_lens, volatile=volatile)
    # Shifted views of ys: ys_i drops the last step, ys_t drops the first
    # (presumably decoder inputs and targets — confirm with the caller).
    ys_i = Variable(ys[..., :-1], volatile=volatile).contiguous()
    ys_t = Variable(ys[..., 1:], volatile=volatile).contiguous()
    ys_lens = Variable(ys_lens - 1, volatile=volatile)
    xys_idx = Variable(xys_idx, volatile=volatile)
    if self.is_cuda:
        # `async` is a reserved keyword since Python 3.7, so `.cuda(async=True)`
        # does not even parse; `non_blocking` is the replacement keyword.
        x = x.cuda(non_blocking=True)
        x_lens = x_lens.cuda(non_blocking=True)
        ys_i = ys_i.cuda(non_blocking=True)
        ys_t = ys_t.cuda(non_blocking=True)
        ys_lens = ys_lens.cuda(non_blocking=True)
        xys_idx = xys_idx.cuda(non_blocking=True)
    return x, x_lens, ys_i, ys_t, ys_lens, xys_idx
def enforce_angle(ang, xnorm, target, margin=0, linearized=False):
    """Enforce a _real_ angular margin on the target-class angles.

    Each row's target-class angle is multiplied by ``margin + 1`` (the
    original note says the +1 keeps parameters consistent with
    enforce_angle). The result goes through ``psi`` and each row is then
    scaled by its entry of ``xnorm``.

    :param ang: 2-D angles (gathered with a (N, 1) index along dim 1)
    :param xnorm: per-row norms, reshaped to (-1, 1)
    :param target: target class index per row, reshaped to (-1, 1)
    :param margin: added to 1 to form the angle multiplier
    :param linearized: forwarded to ``psi``
    :return: ``psi`` of the adjusted angles, scaled row-wise by ``xnorm``
    """
    multiplier = margin + 1
    target_col = target.view(-1, 1)
    boosted = torch.gather(ang, 1, target_col).mul(multiplier)
    adjusted = psi(ang.scatter(1, target_col, boosted), linearized)
    row_norms = xnorm.view(-1, 1).expand_as(adjusted)
    return adjusted.mul(row_norms)
def enforce_angle(ang, xnorm, target, margin=0, linearized=False):
    """Enforce a _real_ angular margin: scale the target angle by margin + 1.

    The scaled angles go through ``psi(..., linearized)`` and each row is
    multiplied by the matching ``xnorm`` value.
    """
    # NOTE(review): this file defines enforce_angle twice with identical
    # bodies; whichever definition comes later is the one bound at import.
    idx = target.view(-1, 1)
    scaled = ang.scatter(1, idx, torch.gather(ang, 1, idx).mul(margin + 1))
    scaled = psi(scaled, linearized)
    return scaled.mul(xnorm.view(-1, 1).expand_as(scaled))
def _choose(self, lang_hs=None, words=None, sample=False):
    """Pick one of the valid choices for ``self.context``.

    The score of each candidate is the sum, over the ``selection_length()``
    slots, of the model's logits for that candidate's item in that slot.

    :param lang_hs: hidden states; defaults to ``torch.cat(self.lang_hs)``
    :param words: words; defaults to ``torch.cat(self.words)``
    :param sample: sample from the softmax if true, otherwise take the argmax
    :return: (choice, logprob, p_agree): the chosen candidate truncated to
        ``selection_length()``, its log-probability (None when not sampling)
        and its probability as a Python float
    """
    # get all the possible choices
    choices = self.domain.generate_choices(self.context)
    # concatenate the list of the hidden states into one tensor
    lang_hs = lang_hs if lang_hs is not None else torch.cat(self.lang_hs)
    # concatenate all the words into one tensor
    words = words if words is not None else torch.cat(self.words)
    # logits for each of the item
    logits = self.model.generate_choice_logits(words, lang_hs, self.ctx_h)
    # construct probability distribution over only the valid choices
    choices_logits = []
    for i in range(self.domain.selection_length()):
        idxs = [self.model.item_dict.get_idx(c[i]) for c in choices]
        idxs = Variable(torch.from_numpy(np.array(idxs)))
        idxs = self.model.to_device(idxs)
        choices_logits.append(torch.gather(logits[i], 0, idxs).unsqueeze(1))
    choice_logit = torch.sum(torch.cat(choices_logits, 1), 1, keepdim=False)
    # Subtract the max to make the softmax numerically stable. detach() keeps
    # the shift out of the graph without `.data[0]`, which fails on the
    # 0-dim result of max() in PyTorch >= 0.4.
    choice_logit = choice_logit - choice_logit.max().detach()
    prob = F.softmax(choice_logit, dim=0)
    if sample:
        # sample a choice; num_samples is a required argument of multinomial
        idx = prob.multinomial(1).detach()
        logprob = F.log_softmax(choice_logit, dim=0).gather(0, idx)
    else:
        # take the most probable choice
        _, idx = prob.max(0, keepdim=True)
        logprob = None
    # Plain Python scalars work with both 1-element (old) and 0-dim (new)
    # indexing results.
    choice_idx = int(idx.data[0])
    p_agree = float(prob.data[choice_idx])
    # Pick only your choice
    return choices[choice_idx][:self.domain.selection_length()], logprob, p_agree
def _test_gather(self, cast, test_bounds=True):
    """Check torch.gather on ``cast``-converted tensors against a reference loop.

    ``cast`` is applied to every tensor before it is used. When ``test_bounds``
    is true, also checks that an out-of-range index raises RuntimeError.
    Finally, checks that gathering with the indices returned by ``max``
    reproduces the max values.
    """
    m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
    elems_per_row = random.randint(1, 10)
    dim = random.randrange(3)
    src = torch.randn(m, n, o)
    idx_size = [m, n, o]
    idx_size[dim] = elems_per_row
    idx = torch.LongTensor().resize_(*idx_size)
    TestTorch._fill_indices(self, idx, dim, src.size(dim), elems_per_row, m, n, o)
    src = cast(src)
    idx = cast(idx)
    actual = torch.gather(src, dim, idx)
    expected = cast(torch.Tensor().resize_(*idx_size))
    # Reference: copy src at each position with the `dim` coordinate replaced by idx
    for i in range(idx_size[0]):
        for j in range(idx_size[1]):
            for k in range(idx_size[2]):
                ii = [i, j, k]
                ii[dim] = idx[i, j, k]
                expected[i, j, k] = src[tuple(ii)]
    self.assertEqual(actual, expected, 0)
    if test_bounds:
        # every dimension of src is at most 20, so 23 is out of range
        idx[0][0][0] = 23
        self.assertRaises(RuntimeError, lambda: torch.gather(src, dim, idx))
    src = cast(torch.randn(3, 4, 5))
    # keepdim=True keeps idx 3-D: gather requires index and src to have the
    # same number of dimensions (plain max(2) returns a 2-D index).
    expected, idx = src.max(2, True)
    expected = cast(expected)
    idx = cast(idx)
    actual = torch.gather(src, 2, idx)
    self.assertEqual(actual, expected, 0)
def reverse_padded_sequence(inputs, lengths, batch_first=False):
    """Reverses sequences according to their lengths.

    Inputs should have size ``T x B x *`` if ``batch_first`` is False, or
    ``B x T x *`` if True. T is the length of the longest sequence (or larger),
    B is the batch size, and * is any number of dimensions (including 0).
    Positions at or beyond a sequence's length are left in place.

    Arguments:
        inputs (Variable): padded batch of variable length sequences.
        lengths (list[int]): list of sequence lengths
        batch_first (bool, optional): if True, inputs should be B x T x *.
    Returns:
        A Variable with the same size as inputs, but with each sequence
        reversed according to its length.
    Raises:
        ValueError: if the batch size of ``inputs`` differs from ``len(lengths)``.
    """
    if not batch_first:
        inputs = inputs.transpose(0, 1)
    if inputs.size(0) != len(lengths):
        raise ValueError('inputs incompatible with lengths.')
    batch_size, max_len = inputs.size(0), inputs.size(1)
    reversed_indices = [list(range(max_len)) for _ in range(batch_size)]
    for i, length in enumerate(lengths):
        if length > 0:
            reversed_indices[i][:length] = reversed_indices[i][length - 1::-1]
    # Give the (B, T) index one trailing singleton per feature dimension so it
    # expands against any B x T x * input. The former unsqueeze(2) only
    # worked when there was exactly one feature dimension.
    trailing = [1] * (inputs.dim() - 2)
    reversed_indices = (torch.LongTensor(reversed_indices)
                        .view(batch_size, max_len, *trailing)
                        .expand_as(inputs))
    reversed_indices = Variable(reversed_indices)
    if inputs.is_cuda:
        device = inputs.get_device()
        reversed_indices = reversed_indices.cuda(device)
    reversed_inputs = torch.gather(inputs, 1, reversed_indices)
    if not batch_first:
        reversed_inputs = reversed_inputs.transpose(0, 1)
    return reversed_inputs
def _test_gather(self, cast, test_bounds=True):
    """Compare torch.gather on ``cast``-converted tensors with a reference loop.

    ``cast`` is applied to every tensor before use. With ``test_bounds`` set,
    an out-of-range index must raise RuntimeError. A final check gathers with
    the keepdim indices from ``max`` and expects the max values back.
    """
    size0, size1, size2 = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
    per_row = random.randint(1, 10)
    axis = random.randrange(3)
    source = torch.randn(size0, size1, size2)
    shape = [size0, size1, size2]
    shape[axis] = per_row
    index = torch.LongTensor().resize_(*shape)
    TestTorch._fill_indices(self, index, axis, source.size(axis), per_row, size0, size1, size2)
    source, index = cast(source), cast(index)
    actual = torch.gather(source, axis, index)
    reference = cast(torch.Tensor().resize_(*shape))
    # Reference: read source with the `axis` coordinate replaced by index[pos]
    positions = ((a, b, c)
                 for a in range(shape[0])
                 for b in range(shape[1])
                 for c in range(shape[2]))
    for pos in positions:
        lookup = list(pos)
        lookup[axis] = index[pos]
        reference[pos] = source[tuple(lookup)]
    self.assertEqual(actual, reference, 0)
    if test_bounds:
        # every dimension of source is at most 20, so 23 is out of range
        index[0][0][0] = 23
        self.assertRaises(RuntimeError, lambda: torch.gather(source, axis, index))
    source = cast(torch.randn(3, 4, 5))
    reference, index = source.max(2, True)
    reference, index = cast(reference), cast(index)
    actual = torch.gather(source, 2, index)
    self.assertEqual(actual, reference, 0)
def mdn_loss(gmm_params, mu, stddev, batchsize):
    """Loss of a reparameterised sample against its nearest mixture component.

    A sample ``z = mu + eps * stddev`` is drawn. Its scaled distance
    ``sqrt(50 * ||z - mu_k||^2)`` to each of the ``args.nmix`` component means
    is computed, and the closest component is selected per row.

    :return: (gmm_loss, gmm_loss_l2): mean of ``-log(pi_min + 1e-30) + d_min``
        and mean of ``d_min`` over the batch
    """
    gmm_mu, gmm_pi = get_gmm_coeffs(gmm_params)
    # randn already draws N(0, 1); the extra normal_() resample was redundant.
    # Noise follows stddev's device instead of an unconditional .cuda().
    eps = Variable(torch.randn(stddev.size()))
    if stddev.is_cuda:
        eps = eps.cuda(stddev.get_device())
    z = torch.add(mu, torch.mul(eps, stddev))
    # Tile z once per component: row b*nmix + k pairs z[b] with component k
    z_flat = z.repeat(1, args.nmix)
    z_flat = z_flat.view(batchsize * args.nmix, args.hiddensize)
    gmm_mu_flat = gmm_mu.view(batchsize * args.nmix, args.hiddensize)
    dist_all = torch.sqrt(torch.sum((z_flat - gmm_mu_flat).pow(2).mul(50), 1))
    dist_all = dist_all.view(batchsize, args.nmix)
    # keepdim=True keeps dist_min at (batch, 1), so it adds element-wise to
    # gmm_pi_min below instead of broadcasting to a (batch, batch) matrix.
    dist_min, selectids = torch.min(dist_all, 1, keepdim=True)
    gmm_pi_min = torch.gather(gmm_pi, 1, selectids)
    gmm_loss = torch.mean(torch.add(-1 * torch.log(gmm_pi_min + 1e-30), dist_min))
    gmm_loss_l2 = torch.mean(dist_min)
    return gmm_loss, gmm_loss_l2
def maskedCE(logits, target, length):
    """
    Args:
        logits: A Variable containing a FloatTensor of size
            (batch, max_len, num_classes) which contains the
            unnormalized probability for each class.
        target: A Variable containing a LongTensor of size
            (batch, max_len) which contains the index of the true
            class for each corresponding step.
        length: A Variable containing a LongTensor of size (batch,)
            which contains the length of each data in a batch.
    Returns:
        loss: An average loss value masked by the length.
    """
    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes); dim given explicitly
    # because the implicit-dim form of log_softmax is deprecated
    log_probs_flat = F.log_softmax(logits_flat, dim=1)
    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)
    # losses_flat: (batch * max_len, 1); tensor method instead of a `t` module alias
    losses_flat = -log_probs_flat.gather(1, target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    # mask: (batch, max_len)
    mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss
def _test_gather(self, cast, test_bounds=True):
    """Compare torch.gather on ``cast``-converted tensors with a reference loop.

    ``cast`` is applied to every tensor before use. With ``test_bounds`` set,
    an out-of-range index must raise RuntimeError. A final check gathers with
    the keepdim indices from ``max`` and expects the max values back.
    """
    size0, size1, size2 = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
    per_row = random.randint(1, 10)
    axis = random.randrange(3)
    source = torch.randn(size0, size1, size2)
    shape = [size0, size1, size2]
    shape[axis] = per_row
    index = torch.LongTensor().resize_(*shape)
    TestTorch._fill_indices(self, index, axis, source.size(axis), per_row, size0, size1, size2)
    source, index = cast(source), cast(index)
    actual = torch.gather(source, axis, index)
    reference = cast(torch.Tensor().resize_(*shape))
    # Reference: read source with the `axis` coordinate replaced by index[pos]
    positions = ((a, b, c)
                 for a in range(shape[0])
                 for b in range(shape[1])
                 for c in range(shape[2]))
    for pos in positions:
        lookup = list(pos)
        lookup[axis] = index[pos]
        reference[pos] = source[tuple(lookup)]
    self.assertEqual(actual, reference, 0)
    if test_bounds:
        # every dimension of source is at most 20, so 23 is out of range
        index[0][0][0] = 23
        self.assertRaises(RuntimeError, lambda: torch.gather(source, axis, index))
    source = cast(torch.randn(3, 4, 5))
    reference, index = source.max(2, True)
    reference, index = cast(reference), cast(index)
    actual = torch.gather(source, 2, index)
    self.assertEqual(actual, reference, 0)
# Source file: masked_cross_entropy.py
# Project: Seq2Seq-on-Word-Sense-Disambiguition
# Author: lbwbowenLi
# Project source / file source
# Views: 24 | Favorites: 0 | Likes: 0 | Comments: 0
def masked_cross_entropy(logits, target, length):
    """
    Args:
        logits: A Variable containing a FloatTensor of size
            (batch, max_len, num_classes) which contains the
            unnormalized probability for each class.
        target: A Variable containing a LongTensor of size
            (batch, max_len) which contains the index of the true
            class for each corresponding step.
        length: the length of each item in the batch, either a Python
            sequence of ints (converted to a LongTensor here) or an existing
            LongTensor/Variable of size (batch,).
    Returns:
        loss: An average loss value masked by the length.
    """
    # This conversion used to sit above the docstring, which turned the
    # docstring into a no-op string expression.
    if not (torch.is_tensor(length) or isinstance(length, Variable)):
        length = Variable(torch.LongTensor(length))
    # Follow logits' device instead of an unconditional .cuda(), so the loss
    # also works on CPU.
    if logits.is_cuda:
        length = length.cuda(logits.get_device())
    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes); explicit dim because
    # the implicit-dim form of log_softmax is deprecated
    log_probs_flat = functional.log_softmax(logits_flat, dim=1)
    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)
    # losses_flat: (batch * max_len, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    # mask: (batch, max_len)
    mask = sequence_mask(sequence_length=length, max_len=target.size(1))
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss
def _test_gather(self, cast, test_bounds=True):
    """Compare torch.gather on ``cast``-converted tensors with a reference loop.

    ``cast`` is applied to every tensor before use. With ``test_bounds`` set,
    an out-of-range index must raise RuntimeError. A final check gathers with
    the keepdim indices from ``max`` and expects the max values back.
    """
    size0, size1, size2 = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
    per_row = random.randint(1, 10)
    axis = random.randrange(3)
    source = torch.randn(size0, size1, size2)
    shape = [size0, size1, size2]
    shape[axis] = per_row
    index = torch.LongTensor().resize_(*shape)
    TestTorch._fill_indices(self, index, axis, source.size(axis), per_row, size0, size1, size2)
    source, index = cast(source), cast(index)
    actual = torch.gather(source, axis, index)
    reference = cast(torch.Tensor().resize_(*shape))
    # Reference: read source with the `axis` coordinate replaced by index[pos]
    positions = ((a, b, c)
                 for a in range(shape[0])
                 for b in range(shape[1])
                 for c in range(shape[2]))
    for pos in positions:
        lookup = list(pos)
        lookup[axis] = index[pos]
        reference[pos] = source[tuple(lookup)]
    self.assertEqual(actual, reference, 0)
    if test_bounds:
        # every dimension of source is at most 20, so 23 is out of range
        index[0][0][0] = 23
        self.assertRaises(RuntimeError, lambda: torch.gather(source, axis, index))
    source = cast(torch.randn(3, 4, 5))
    reference, index = source.max(2, True)
    reference, index = cast(reference), cast(index)
    actual = torch.gather(source, 2, index)
    self.assertEqual(actual, reference, 0)
def forward(self, lstm_out, lengths):
    """
    Args:
        lstm_out: A Variable containing a 3D tensor of dimension
            (seq_len, batch_size, hidden_x_dirs)
        lengths: A Variable containing 1D LongTensor of dimension
            (batch_size)
    Return:
        A Variable containing a 2D tensor of the same type as lstm_out of
        dim (batch_size, hidden_x_dirs): the forward half of lstm_out read at
        step ``length - 1`` concatenated with the backward half read at step
        ``seq_len - length``.
    """
    seq_len = lstm_out.size(0)
    batch_size = lstm_out.size(1)
    hidden_x_dirs = lstm_out.size(2)
    # Integer division: `/` yields a float on Python 3, which repeat() rejects.
    single_dir_hidden = hidden_x_dirs // 2
    lengths_fw = lengths
    lengths_bw = seq_len - lengths_fw
    rep_lengths_fw = lengths_fw.view(1, batch_size, 1)
    rep_lengths_fw = rep_lengths_fw.repeat(1, 1, single_dir_hidden)
    rep_lengths_bw = lengths_bw.view(1, batch_size, 1)
    rep_lengths_bw = rep_lengths_bw.repeat(1, 1, single_dir_hidden)
    # we want 2 chunks in the last dimension
    out_fw, out_bw = torch.chunk(lstm_out, 2, 2)
    # NOTE(review): the backward index seq_len - length assumes a padding
    # layout that is not visible here — confirm against the caller.
    h_t_fw = torch.gather(out_fw, 0, rep_lengths_fw - 1)
    h_t_bw = torch.gather(out_bw, 0, rep_lengths_bw)
    # (1, batch_size, hidden_x_dirs) -> (batch_size, hidden_x_dirs); squeeze(0)
    # rather than squeeze() so a batch of one keeps its batch dimension.
    last_hidden_out = torch.cat([h_t_fw, h_t_bw], 2).squeeze(0)
    return last_hidden_out
def sort_by_embeddings(self, Phis, Inputs_N, e):
    """Reorder every entry of Phis and Inputs_N by ascending value of ``e``.

    The permutation comes from sorting ``e`` along dim 1, followed by
    ``.squeeze()`` (NOTE(review): with batch size 1 this also drops the batch
    dim). Each ``Phis[i]`` is permuted along both dims 1 and 2; each
    ``Inputs_N[i]`` along dim 1. Both lists are modified in place and
    returned.
    """
    order = torch.sort(e, 1)[1].squeeze()
    for scale, phis in enumerate(Phis):
        row_index = order.unsqueeze(2).expand_as(phis)
        col_index = order.unsqueeze(1).expand_as(phis)
        # rearrange phis: rows first, then columns
        Phis[scale] = torch.gather(torch.gather(Phis[scale], 1, row_index), 2, col_index)
        # rearrange inputs
        inputs = Inputs_N[scale]
        Inputs_N[scale] = torch.gather(inputs, 1, order.unsqueeze(2).expand_as(inputs))
    return Phis, Inputs_N
def combine_matrices(self, prob_matrix, prob_matrix_scale, perm):
    """Fold one scale's assignment matrix into the running state.

    The new hard permutation comes from ``self.discretize(prob_matrix_scale)``
    (labelled argmax in the original) and is composed with ``perm`` via a
    gather along dim 1. The soft matrix is the batched product
    ``prob_matrix_scale @ prob_matrix``.

    :return: (prob_matrix, perm) after combination
    """
    new_perm = self.discretize(prob_matrix_scale)
    composed_perm = torch.gather(perm, 1, new_perm)
    return torch.bmm(prob_matrix_scale, prob_matrix), composed_perm
def outputs(self, input, prob_matrix, perm):
    """Return (hard, soft) reorderings of ``input``.

    hard: rows of ``input`` gathered along dim 1 in the order given by ``perm``.
    soft: the batched product ``prob_matrix @ input`` (soft argmax).
    """
    row_index = perm.unsqueeze(2).expand_as(input)
    return input.gather(1, row_index), prob_matrix.bmm(input)
def combine_matrices(self, prob_matrix, prob_matrix_scale,
                     perm, last=False):
    """Prepend a fixed first row to prob_matrix_scale and fold it into the state.

    A row ``[1, 0, ..., 0]`` is concatenated in front of ``prob_matrix_scale``
    along dim 1. The hard permutation from ``self.discretize`` is composed with
    ``perm``, and ``prob_matrix`` is left-multiplied by the extended matrix.
    ``last`` is accepted but not used in this body.

    :return: (prob_matrix, prob_matrix_scale, perm) after combination
    """
    # Original note: prob_matrix shape is bs x length x length + 1; the cat
    # below adds the extra entry as a leading row of prob_matrix_scale.
    length = prob_matrix_scale.size()[2]
    # Batch size taken from the tensor itself rather than self.batch_size, so
    # a final batch smaller than the configured size still concatenates.
    batch_size = prob_matrix_scale.size()[0]
    first = Variable(torch.zeros([batch_size, 1, length])).type(dtype)
    first[:, 0, 0] = 1.0
    prob_matrix_scale = torch.cat((first, prob_matrix_scale), 1)
    # argmax
    new_perm = self.discretize(prob_matrix_scale)
    perm = torch.gather(perm, 1, new_perm)
    # combine
    prob_matrix = torch.bmm(prob_matrix_scale, prob_matrix)
    return prob_matrix, prob_matrix_scale, perm
def outputs(self, input, prob_matrix, perm):
    """Compute hard and soft outputs from a permutation and its soft matrix.

    :param input: 3-D batch whose dim-1 rows are reordered
    :param prob_matrix: batched matrix applied to ``input`` with bmm
    :param perm: per-batch row order used for the hard output
    :return: (hard_output, soft_output)
    """
    expanded_perm = perm.unsqueeze(2).expand_as(input)
    hard_result = input.gather(1, expanded_perm)
    # soft argmax
    soft_result = prob_matrix.bmm(input)
    return hard_result, soft_result
def deploy(x, labels):
    """Evaluate the module-level model ``m`` and criterion ``crit`` on one batch.

    :param x: model input
    :param labels: Variable of gold class indices, one per row
    :return: (loss, top1_acc, top5_acc) where loss is a Python float and the
        accuracies are fractions of rows whose gold class ranks first / in
        the top five
    """
    pred = m(x)
    loss = crit(pred, labels)
    # bests[b, r] is the class with the r-th highest score in row b
    _, bests = pred.topk(pred.size(1), dim=1)
    # ranking[b, c] is the rank of class c in row b  # [batch_size, dict_size]
    _, ranking = bests.topk(bests.size(1), dim=1, largest=False)
    rank = torch.gather(ranking.data, 1, labels.data[:, None]).cpu().numpy().squeeze()
    top1_acc = np.mean(rank == 0)
    top5_acc = np.mean(rank < 5)
    # loss.data is a 1-element tensor in old PyTorch and 0-dim from 0.4 on;
    # view(-1)[0] followed by float() works for both, unlike .data[0].
    return float(loss.data.view(-1)[0]), top1_acc, top5_acc
def devise_train(m, x, labels, data, att_crit=None, optimizers=None):
    """
    Compute a margin (0.1) hinge loss between image embeddings and verb embeddings.
    :param m: Model we're using; ``m(x).embed_pred`` is the image embedding and
        ``m.l2_penalty`` is added to the loss
    :param x: [batch_size, 3, 224, 224] Image input
    :param labels: [batch_size] variable with indices of the right verbs
    :param data: provides ``data.attributes.embeds``, the verb embeddings
    :param att_crit: not used in this body
    :param optimizers: the decorator will use these to update parameters
    :return: the scalar loss
    """
    # Make embed unit normed
    normed_embeds = _normalize(data.attributes.embeds)
    image_embeds = m(x).embed_pred
    # Similarity of every image to every verb (presumably [batch_size, vocab_size])
    scores = image_embeds @ normed_embeds.t()
    # Score of the correct verb for each example
    gold = torch.gather(scores, 1, labels[:, None])
    # Hinge against every verb. The correct verb's own term is the constant
    # 0.1, whose gradient w.r.t. the input is 0, so it is not masked out.
    hinge = (0.1 + scores - gold.expand_as(scores)).clamp(min=0.0)
    return m.l2_penalty + hinge.sum(1).squeeze().mean()