import math
import torch


def positional_embedding(x, min_timescale=1.0, max_timescale=1.0e4):
    """Return the sinusoidal positional encoding for a (batch, length, channels) tensor."""
    batch, length, channels = list(x.size())
    assert channels % 2 == 0
    num_timescales = channels // 2
    # Timescales are spaced geometrically between min_timescale and max_timescale.
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        (float(num_timescales) - 1.))
    position = torch.arange(0, length).float()
    inv_timescales = torch.arange(0, num_timescales).float()
    if x.is_cuda:
        position = position.cuda()
        inv_timescales = inv_timescales.cuda()
    # inv_timescales[i] = min_timescale * exp(-i * log_timescale_increment)
    inv_timescales.mul_(-log_timescale_increment).exp_().mul_(min_timescale)
    # Outer product of positions and inverse timescales: (length, num_timescales).
    scaled_time = position.unsqueeze(1).expand(
        length, num_timescales) * inv_timescales.unsqueeze(0).expand(length, num_timescales)
    # Concatenate the sin and cos halves along the channel dimension: (length, channels).
    signal = torch.cat([scaled_time.sin(), scaled_time.cos()], 1)
    # Broadcast across the batch dimension: (batch, length, channels).
    return signal.unsqueeze(0).expand(batch, length, channels)
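
As a quick sanity check, the returned signal can simply be added to the token embeddings. The vocabulary size, model width, and embedding layer below are illustrative assumptions, not part of the original snippet:

# Minimal usage sketch (illustrative shapes and layers, not from the original code).
vocab_size, d_model = 1000, 512
embed = torch.nn.Embedding(vocab_size, d_model)

tokens = torch.randint(0, vocab_size, (2, 10))   # (batch=2, length=10)
x = embed(tokens)                                # (2, 10, 512)
x = x + positional_embedding(x)                  # inject position information
print(x.shape)                                   # torch.Size([2, 10, 512])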