def dilate(sigs, dilation):
    """
    Reshape a (N, C, L) tensor so that its first dimension becomes ``dilation``.

    For every channel, the N*L samples are redistributed into ``dilation``
    rows of length ``ceil(L*N / dilation)``.  When that does not divide
    evenly, the length axis of the input is first zero padded at the start.

    :param sigs: Tensor of size (N, C, L), where N is the current dilation,
        C is the number of channels and L is the length.
    :param dilation: Target dilation; becomes the size of the first dimension
        of the output tensor. Must be a positive integer (e.g. ``4`` or ``4.0``).
    :return: ``(dilated, pad_len)``. ``dilated`` has size
        (dilation, C, ceil(L*N / dilation)); it is ``sigs`` itself when N
        already equals ``dilation``. ``pad_len`` is the int number of zeros
        prepended to the input's length axis (0 when no padding was needed).
    :raises ValueError: if ``dilation`` is not a positive integer, or if the
        required padding is not a whole number of samples per input row
        (uneven padding is not implemented).
    """
    import logging

    n, c, l = sigs.size()
    new_n = int(dilation)
    # A non-integral dilation would make the row count and the row length
    # disagree, so reject it instead of failing later inside view().
    if new_n != dilation or new_n < 1:
        raise ValueError(
            "dilation must be a positive integer, got {!r}".format(dilation))
    if new_n == n:
        # Already at the target dilation: nothing to reshape.
        return sigs, 0

    # Integer ceil(l * n / new_n): exact for any size, unlike float division.
    new_l = (l * n + new_n - 1) // new_n
    # Every one of the n input rows must grow by the same whole number of
    # samples so that n * (l + pad_len) == new_n * new_l.
    total_pad = new_n * new_l - n * l
    if total_pad % n:
        # TODO pad output tensor unevenly for indivisible dilations
        raise ValueError(
            "cannot dilate {} rows of length {} to dilation {}: {} padding "
            "samples are not divisible by {}".format(n, l, new_n, total_pad, n))
    pad_len = total_pad // n

    if pad_len > 0:
        logging.getLogger(__name__).debug(
            "Padding: %s, %s, %s", new_n, new_l, pad_len)
        # NOTE(review): pad1d is defined elsewhere. This assumes it follows the
        # torch.nn.functional.pad ordering (last dimension first), so
        # (pad_len, 0, 0, 0) prepends pad_len zeros to the length axis and
        # leaves the other axes unpadded. The original author notes it
        # unsqueezes internally because the pad function wants 4-D/5-D input.
        # Confirm against its definition.
        padding = (pad_len, 0, 0, 0)
        sigs = pad1d(sigs, padding)

    # Flatten each channel time-major across the n rows
    # (x[0, t], x[1, t], ..., x[n-1, t], x[0, t+1], ...), then split that
    # sequence into new_n interleaved rows of length new_l.
    sigs = sigs.permute(1, 2, 0).contiguous()  # (n, c, l) -> (c, l, n)
    sigs = sigs.view(c, new_l, new_n)
    sigs = sigs.permute(2, 0, 1).contiguous()  # (c, new_l, new_n) -> (new_n, c, new_l)
    return sigs, pad_len
评论列表
文章目录