test_model_io.py 文件源码

python
阅读 40 收藏 0 点赞 0 评论 0

项目:ip-avsr 作者: lzuwei 项目源码 文件源码
def test_load_params(self):
        window = T.iscalar('theta')
        inputs1 = T.tensor3('inputs1', dtype='float32')
        mask = T.matrix('mask', dtype='uint8')
        network = deltanet_majority_vote.load_saved_model('../oulu/results/best_models/1stream_mfcc_w3s3.6.pkl',
                                                          ([500, 200, 100, 50], [rectify, rectify, rectify, linear]),
                                                          (None, None, 91), inputs1, (None, None), mask,
                                                          250, window, 10)
        d = deltanet_majority_vote.extract_encoder_weights(network, ['fc1', 'fc2', 'fc3', 'bottleneck'],
                                                           [('w1', 'b1'), ('w2', 'b2'), ('w3', 'b3'), ('w4', 'b4')])
        b = deltanet_majority_vote.extract_lstm_weights(network, ['f_blstm1', 'b_blstm1'],
                                                        ['flstm', 'blstm'])
        expected_keys = ['w1', 'w2', 'w3', 'w4', 'b1', 'b2', 'b3', 'b4']
        keys = d.keys()
        for k in keys:
            assert k in expected_keys
            assert type(d[k]) == np.ndarray
        save_mat(d, '../oulu/models/oulu_1stream_mfcc_w3s3.mat')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号