layers.py source code


Project: fast-wavenet.pytorch    Author: dhpollack
import numpy as np

# NOTE: pad1d is presumably a helper defined elsewhere in layers.py; it
# zero-pads the last dimension of a 3d tensor (the underlying pad function
# only works with 4d/5d tensors, hence the "squeeze"/"unsqueeze" below).


def dilate(sigs, dilation):
    """Reshape a batch of signals to a target dilation, fast-wavenet style.

    Note this will fail if the dilation doesn't allow a whole number amount of padding.

    :param sigs: Tensor or Variable of size (N, C, L), where N is the input dilation, C is the number of channels, and L is the input length
    :param dilation: Target dilation. Will be the size of the first dimension of the output tensor.
    :return: The dilated Tensor or Variable of size (dilation, C, L*N / dilation), together with the amount of zero padding added. If the input length is not compatible with the target dilation, zeros are padded at the start.
    """

    n, c, l = sigs.size()
    dilation_factor = dilation / n
    if dilation_factor == 1:
        # already at the target dilation; nothing to do
        return sigs, 0.

    # zero padding for reshaping
    new_n = int(dilation)
    new_l = int(np.ceil(l*n/dilation))
    pad_len = (new_n*new_l-n*l)/n
    if pad_len > 0:
        print("Padding: {}, {}, {}".format(new_n, new_l, pad_len))
        # TODO pad output tensor unevenly for indivisible dilations
        assert pad_len == int(pad_len)
        # "squeeze" then "unsqueeze" due to limitation of pad function
        # which only works with 4d/5d tensors
        padding = (int(pad_len), 0, 0, 0) # (d3_St, d3_End, d2_St, d2_End), d0 and d1 unpadded
        sigs = pad1d(sigs, padding)

    # reshape according to dilation
    sigs = sigs.permute(1, 2, 0).contiguous()  # (n, c, l) -> (c, l, n)
    sigs = sigs.view(c, new_l, new_n)          # fold time steps into the batch dimension
    sigs = sigs.permute(2, 0, 1).contiguous()  # (c, new_l, new_n) -> (new_n, c, new_l)

    return sigs, pad_len
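
The snippet calls pad1d without showing it, so here is a minimal usage sketch. The pad1d stand-in below is an assumption, not the project's actual helper: it lifts the 3d tensor to 4d so torch.nn.functional.pad (which only supported 4d/5d inputs in the PyTorch versions of this era) can pad the start of the time dimension. The example dilates a (1, 1, 8) signal to dilation 2, splitting the sequence into its even and odd time steps, which is the interleaving fast-wavenet uses to turn dilated convolutions into ordinary ones.

import torch
import torch.nn.functional as F

def pad1d(sigs, padding):
    # hypothetical stand-in for the pad1d helper in layers.py:
    # add a leading dim, pad the last dimension, then drop the extra dim
    return F.pad(sigs.unsqueeze(0), padding).squeeze(0)

x = torch.arange(8.).view(1, 1, 8)  # (N=1, C=1, L=8)
y, pad_len = dilate(x, 2)           # pad1d is only invoked when L*N isn't divisible by the dilation
print(y.size())  # torch.Size([2, 1, 4])
print(y[:, 0])   # tensor([[0., 2., 4., 6.],
                 #         [1., 3., 5., 7.]])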