filters_bank.py source file

python

Project: pyscatwave  Author: edouardoyallon
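import numpy as np
import scipy.fftpack as fft  # assumed source of fft2; numpy.fft offers the same call
import torch

# morlet_2d, gabor_2d and crop_freq are assumed to be defined elsewhere in this module.


def filters_bank(M, N, J, L=8):
    """Return the band-pass 'psi' filters (J scales x L orientations) and the
    low-pass 'phi' filter, stored in the Fourier domain at each resolution."""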
    filters = {}
    filters['psi'] = []


    offset_unpad = 0
    for j in range(J):
        for theta in range(L):
            psi = {}
            psi['j'] = j
            psi['theta'] = theta
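            # Original note: "The 5 is here just to match the LUA implementation :)"
            # (no literal 5 appears in this version of the call).
            psi_signal = morlet_2d(
                M, N,
                0.8 * 2**j,                              # scale, doubles with each j
                (int(L-L/2-1)-theta) * np.pi / L,        # orientation angle
                3.0 / 4.0 * np.pi / 2**j,                # frequency, halves with each j
                offset=offset_unpad)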
            psi_signal_fourier = fft.fft2(psi_signal)
            for res in range(j + 1):
                psi_signal_fourier_res = crop_freq(psi_signal_fourier, res)
                # Store real and imaginary parts along a trailing axis of size 2.
                psi[res] = torch.FloatTensor(np.stack(
                    (np.real(psi_signal_fourier_res), np.imag(psi_signal_fourier_res)),
                    axis=2))
                # Normalization to avoid doing it with the FFT!
                psi[res].div_(M*N // 2**(2*j))
            filters['psi'].append(psi)

    filters['phi'] = {}
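    # Low-pass filter: gabor_2d called at the coarsest scale J-1 with both remaining
    # positional arguments set to 0.
    phi_signal = gabor_2d(M, N, 0.8 * 2**(J-1), 0, 0, offset=offset_unpad)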
    phi_signal_fourier = fft.fft2(phi_signal)
    filters['phi']['j'] = J
    for res in range(J):
        phi_signal_fourier_res = crop_freq(phi_signal_fourier, res)
        # Same (real, imag) layout and FFT-free normalization as the psi filters.
        filters['phi'][res] = torch.FloatTensor(np.stack(
            (np.real(phi_signal_fourier_res), np.imag(phi_signal_fourier_res)),
            axis=2))
        filters['phi'][res].div_(M*N // 2 ** (2 * J))

    return filters
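
The loop structure above determines how many filters are built and which parameters each one receives. The following is a minimal sketch that mirrors only that bookkeeping, without building any wavelets. `describe_filter_bank` is a hypothetical helper, not part of pyscatwave, and the key names `scale`, `angle` and `frequency` are labels chosen here for the three positional arguments passed to `morlet_2d`.

```python
import math


def describe_filter_bank(J, L=8):
    """Hypothetical helper: list the parameters filters_bank would use per psi filter."""
    layout = []
    for j in range(J):
        for theta in range(L):
            layout.append({
                'j': j,
                'theta': theta,
                # Same expressions as the morlet_2d call in filters_bank.
                'scale': 0.8 * 2**j,
                'angle': (int(L - L / 2 - 1) - theta) * math.pi / L,
                'frequency': 3.0 / 4.0 * math.pi / 2**j,
                # Resolutions at which the cropped Fourier transform is stored.
                'resolutions': list(range(j + 1)),
            })
    return layout


layout = describe_filter_bank(J=3, L=8)
print(len(layout))                                   # 24 psi entries (J * L)
print(sum(len(p['resolutions']) for p in layout))    # 48 stored psi tensors
```

With `J=3` and `L=8` the bank holds 24 `psi` entries, and because a filter at scale `j` is kept at resolutions `0..j`, those entries hold 48 tensors in total. The `phi` filter adds one tensor per resolution `0..J-1`.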