testing.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:pyMTL 作者: bibliolytic 项目源码 文件源码
def _lagged_cov_features(task_data, n_lags):
    """Return per-trial traces of symmetrised lagged covariance matrices.

    ``task_data`` is a 3-D array whose last axis is iterated as individual
    trials and whose middle axis is the sample (time) axis.
    (assumes the first axis is channels -- TODO confirm against the .mat file)

    The result has shape ``(n_trials, n_lags)``; entry ``[i, j]`` is
    ``trace(C + C.T)`` where ``C`` is the covariance between trial ``i`` and
    a copy of itself shifted by ``j`` samples.
    """
    trials = task_data.transpose((2, 0, 1))
    n_samples = trials.shape[2]
    features = np.zeros((trials.shape[0], n_lags))
    for trial_idx, trial in enumerate(trials):
        for lag in range(n_lags):
            head = trial[:, 0:(n_samples - lag)]
            tail = trial[:, lag:]
            # 1.0 forces true division: under Python 2 a bare ``1/(n-1)``
            # is integer division and silently zeroes every covariance.
            cov = (1.0 / (head.shape[1] - 1)) * head.dot(tail.T)
            features[trial_idx, lag] = np.trace(cov + cov.T)
    return features


def test_temporal_learning(
        flen=(5,),
        data_path='/is/ei/vjayaram/code/python/pyMTL_MunichMI.mat',
        fs=500):
    """Fit the temporal-filter MTL model and plot the recovered FIR filters.

    For every filter length in ``flen`` the per-task features are built from
    traces of lagged covariance matrices, a ``TemporalBRC`` model is fitted
    on all tasks, FIR coefficients are derived from the learned prior mean
    and their frequency response (magnitude in dB and unwrapped phase) is
    drawn in one subplot so the spectral profile can be inspected visually.

    Parameters
    ----------
    flen : iterable of int
        Filter lengths (number of lags) to evaluate; one subplot each.
    data_path : str
        MATLAB file holding the variables ``'T'`` (per-task data) and
        ``'Y'`` (per-task labels).
    fs : number
        Sampling rate in Hz used to convert the normalised frequencies
        returned by ``freqz`` into Hz.

    The function returns nothing; it blocks in ``plt.show()``.
    """
    flen = list(flen)

    D = scio.loadmat(data_path)
    data = D['T'].ravel()
    labels = D['Y'].ravel()
    # Flattening the label vectors does not depend on the filter length,
    # so do it once up front instead of once per filter length.
    for tind in range(len(data)):
        labels[tind] = labels[tind].ravel()

    fig = plt.figure()
    fig.suptitle('Recovered frequency filters for various filter lengths')
    # NOTE(review): the same model instance is refit for every filter length;
    # confirm fit_multi_task re-initialises the prior between fits.
    model = TemporalBRC(max_prior_iter=100)

    # Keep the historical 3x3 layout, but add rows when more than nine filter
    # lengths are requested (a fixed 3x3 grid raises ValueError on the 10th).
    ncols = 3
    nrows = max(3, (len(flen) + ncols - 1) // ncols)

    for find, freq in enumerate(flen):
        # compute lagged covariance features for every task
        X = [_lagged_cov_features(d, freq) for d in data]

        model.fit_multi_task(X, labels)
        b = numerics.solve_fir_coef(model.prior.mu.flatten())[0]
        # look at filter properties
        # NOTE(review): the first coefficient is dropped before freqz --
        # presumably a fixed leading term; confirm against solve_fir_coef.
        w, h = freqz(b[1:], worN=100)
        w = w * fs / 2 / np.pi  # rad/sample -> Hz

        # Draw on explicit axes rather than relying on pyplot's notion of
        # the "current" axes, which twinx() changes as a side effect.
        ax1 = fig.add_subplot(nrows, ncols, find + 1)
        ax1.plot(w, 20 * np.log10(abs(h)), 'b')
        ax1.set_ylabel('Amplitude [dB]', color='b')
        ax1.set_xlabel('Frequency [Hz]')
        ax2 = ax1.twinx()
        angles = np.unwrap(np.angle(h))
        ax2.plot(w, angles, 'g')
        ax2.set_ylabel('Angle (radians)', color='g')
        ax2.grid()
        ax2.axis('tight')
    plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号