def expand_z_where(z_where):
    """Turn a batch of three-vectors into a batch of 2x3 matrices.

    Each row ``[s, x, y]`` of ``z_where`` is rearranged as::

        [[s, 0, x],
         [0, s, y]]

    The result is viewed as ``(n, 2, 3)``, where ``n`` is
    ``z_where.size(0)``.
    """
    batch_size = z_where.size(0)

    # Prepend a zero column (same tensor type as the input) so that a
    # single column gather can place the zeros of the target layout.
    zero_column = ng_zeros([1, 1]).type_as(z_where).expand(batch_size, 1)
    padded = torch.cat((zero_column, z_where), 1)

    # `expansion_indices` is defined outside this function; it presumably
    # lists the padded columns in the order s, 0, x, 0, s, y -- TODO confirm.
    column_order = Variable(expansion_indices)
    if z_where.is_cuda:
        # Move the index tensor to the GPU when the input lives there.
        column_order = column_order.cuda()

    gathered = torch.index_select(padded, 1, column_order)
    return gathered.view(batch_size, 2, 3)
# Scaling by `1/scale` here is unsatisfactory, as `scale` could be
# zero.
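# [web-page residue, not part of the program] Comment list
# [web-page residue, not part of the program] Table of contents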