def test_waveform_data_source(self):
    """Check that WaveformDataSource serves processed waveforms by file key.

    Wraps a FileDataSource over the ``.wav`` files in DUMMY_DATA_PATH with
    ``dummy_process_waveforms`` as the processing hook. Then it verifies one
    known fixture and every ``.wav`` file found in that directory.
    """
    ds = WaveformDataSource(
        FileDataSource(DUMMY_DATA_PATH, suffix='.wav'),
        process_waveform=dummy_process_waveforms,
    )

    # Spot-check a known fixture. assert_array_equal broadcasts the 0-d
    # expectation like the element-wise ``==`` did. Unlike
    # assertTrue(np.all(...)), it reports the offending values on failure.
    npt.assert_array_equal(ds['1_sad_kid_1'], np.array(2))

    # os.listdir order is arbitrary; sorting keeps failure output deterministic.
    wav_names = sorted(
        name for name in os.listdir(DUMMY_DATA_PATH) if name.endswith('.wav')
    )
    # Guard against a vacuous pass: with no .wav files the loop asserts nothing.
    self.assertTrue(wav_names, 'no .wav fixtures found in %s' % DUMMY_DATA_PATH)

    for name in wav_names:
        path = os.path.join(DUMMY_DATA_PATH, name)
        # Key is the file name up to the first '.', same derivation as before.
        # NOTE(review): assumed to match FileDataSource's key convention — confirm.
        key = name.split('.')[0]
        # Compare file by file instead of stacking everything into one
        # np.array: this avoids building a ragged array if waveforms differ
        # in length, and err_msg names the file that mismatched.
        # Presumably element [1] of dummy_process_waveforms' result is the
        # value the data source stores — TODO confirm.
        npt.assert_array_equal(
            ds[key],
            dummy_process_waveforms(path)[1],
            err_msg='mismatch for %s' % path,
        )
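# NOTE: stray web-page navigation labels ("Comment list", "Table of contents")
# were scraped into this file here; they are not part of the test code.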