import torch
import torch.nn.functional as F
from torch.autograd import Variable  # legacy (pre-0.4) PyTorch API, as used throughout this snippet

def sample(self, fc_feats, att_feats, opt={}):
    sample_max = opt.get('sample_max', 1)
    beam_size = opt.get('beam_size', 1)
    temperature = opt.get('temperature', 1.0)
    if beam_size > 1:
        # beam search is delegated to a separate routine
        return self.sample_beam(fc_feats, att_feats, opt)
    batch_size = fc_feats.size(0)
    state = self.init_hidden(batch_size)
    seq = []
    seqLogprobs = []
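    # The loop runs for seq_length + 2 steps: t = 0 feeds the image embedding,
    # t = 1 feeds the <bos> token, and words are sampled from t = 2 onward.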
    for t in range(self.seq_length + 2):
        if t == 0:
            xt = self.img_embed(fc_feats)
        else:
            if t == 1:  # input <bos>
                it = fc_feats.data.new(batch_size).long().zero_()
            elif sample_max:
                sampleLogprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
            else:
                if temperature == 1.0:
                    prob_prev = torch.exp(logprobs.data).cpu()  # fetch prev distribution: shape Nx(M+1)
                else:
                    # scale logprobs by temperature
                    prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
                it = torch.multinomial(prob_prev, 1).cuda()
                sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False))  # gather the logprobs at sampled positions
                it = it.view(-1).long()  # and flatten indices for downstream processing
            xt = self.embed(Variable(it, requires_grad=False))
        if t >= 2:
            # stop when all sequences have finished (sampled token 0)
            if t == 2:
                unfinished = it > 0
            else:
                unfinished = unfinished * (it > 0)
            if unfinished.sum() == 0:
                break
            it = it * unfinished.type_as(it)  # zero out tokens of already-finished sequences
            seq.append(it)  # seq[t] is the word sampled at step t, i.e. the input at step t + 2
            seqLogprobs.append(sampleLogprobs.view(-1))

        output, state = self.core(xt, state)
        logprobs = F.log_softmax(self.logit(output))  # note: newer PyTorch requires an explicit dim=1 here

    return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1)
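
For context, a minimal usage sketch follows. The names model, fc_feats, and att_feats are assumptions here (a trained ShowTell-style captioner and its precomputed CNN features), not something defined in this snippet:

# Greedy decoding: sample_max=1 takes the argmax token at every step.
seq, seqLogprobs = model.sample(fc_feats, att_feats, opt={'sample_max': 1})

# Stochastic decoding: sample_max=0 samples from the softmax distribution;
# temperature < 1 sharpens it, temperature > 1 flattens it.
seq, seqLogprobs = model.sample(fc_feats, att_feats,
                                opt={'sample_max': 0, 'temperature': 0.8})

# seq is a (batch_size, <= seq_length) LongTensor of word indices, with 0
# marking end-of-sentence/padding; seqLogprobs holds the log-probability of
# each sampled word.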
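To see what dividing the log-probabilities by the temperature does to the sampling distribution, here is a tiny standalone sketch (modern PyTorch, no Variable):

import torch

logprobs = torch.log_softmax(torch.tensor([2.0, 1.0, 0.1]), dim=0)
for T in (0.5, 1.0, 2.0):
    # lower T concentrates probability mass on the argmax; higher T spreads it out
    print(T, torch.softmax(logprobs / T, dim=0))

Since torch.multinomial normalizes its input, sampling from exp(logprobs / T) as in the method above is equivalent to sampling from this renormalized softmax.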