def schedule_sampling(self, prev, dec_out):
"""
Resample n inputs to next iteration from the model itself. N is itself
sampled from a bernoulli independently for each example in the batch
with weights equal to the model's variable self.scheduled_rate.
Parameters:
-----------
- prev: torch.LongTensor(batch_size)
- dec_out: torch.Tensor(batch_size x hid_dim)
Returns: partially resampled input
--------
- prev: torch.LongTensor(batch_size)
"""
    prev, dec_out = prev.data, dec_out.data  # don't register computation
    # keep_mask[i] == 1 means example i keeps its ground-truth input
    keep_mask = torch.bernoulli(
        torch.zeros_like(prev).float() + self.exposure_rate) == 1
    # return if no resampling is necessary (every input is kept)
    if len(keep_mask.nonzero()) == len(prev):
        return prev
    # greedy predictions from the previous decoder output (no gradients needed)
    sampled = self.decoder.project(
        Variable(dec_out, volatile=True)).max(1)[1].data
    if keep_mask.nonzero().dim() == 0:  # nothing kept: return all sampled
        return sampled
    # restore the ground-truth tokens at the positions that are kept
    keep_mask = keep_mask.nonzero().squeeze(1)
    sampled[keep_mask] = prev[keep_mask]
    return sampled
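
For context, here is a minimal sketch of how schedule_sampling might be called from a teacher-forced decoding loop during training. The surrounding method, rnn_step, embed and trg are illustrative stand-ins rather than part of the original code; they only show where the gold input gets partially replaced by the model's own prediction, in the same old-PyTorch Variable style as above.

def decode_train(self, trg, hidden):
    # trg: Variable wrapping torch.LongTensor(seq_len x batch_size) of gold tokens
    outs, dec_out = [], None
    for step in range(trg.size(0)):
        prev = trg[step]
        if dec_out is not None:
            # mix gold tokens with the model's previous predictions;
            # re-wrap in Variable since schedule_sampling returns raw tensors
            prev = Variable(self.schedule_sampling(prev, dec_out))
        dec_out, hidden = self.rnn_step(self.embed(prev), hidden)
        outs.append(dec_out)
    return outs

With self.exposure_rate at 1 every step is pure teacher forcing; annealing it towards 0 during training gradually exposes the decoder to its own predictions.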