def get_batch(batch_size):
    """Build a batch of mixed synthetic signals and their scaled power spectra.

    For each of the ``batch_size`` rows, a random number of sources is drawn
    from the inclusive range ``[min_sources, max_sources]``. Each source's
    ``(frequency, sample)`` pair comes from ``generate_sample()``. The samples
    are summed into the row, which is then divided by the number of sources
    (i.e. the row is the mean of its sources). The row's power spectrum,
    ``real(rfft)**2 + imag(rfft)**2`` with orthonormal scaling, is multiplied
    by ``fft_norm`` and stored.

    Relies on names defined outside this function: ``np``, ``sample_length``,
    ``fft_size``, ``min_sources``, ``max_sources``, ``generate_sample`` and
    ``fft_norm``.

    Args:
        batch_size: Number of rows to generate.

    Returns:
        A tuple ``(frequencies, samples, ffts)``:

        - ``frequencies``: list of ``batch_size`` independent sets; set ``i``
          holds the frequency values returned by ``generate_sample`` for row ``i``.
        - ``samples``: array of shape ``(batch_size, sample_length)``.
        - ``ffts``: array of shape ``(batch_size, fft_size)``.
    """
    samples = np.zeros([batch_size, sample_length])
    # One distinct set per row. ``[set()] * batch_size`` would repeat a single
    # shared set object, so every row would end up with every row's frequencies.
    frequencies = [set() for _ in range(batch_size)]
    ffts = np.zeros([batch_size, fft_size])
    for i in range(batch_size):
        # randint's upper bound is exclusive, hence the +1 to include max_sources.
        num_sources = np.random.randint(min_sources, max_sources + 1)
        for _ in range(num_sources):
            frequency, sample = generate_sample()
            samples[i] += sample
            frequencies[i].add(frequency)
        # NOTE(review): if min_sources can be 0, this divides by zero and the row
        # becomes NaN — presumably min_sources >= 1; confirm.
        samples[i] /= float(num_sources)
        # NOTE(review): assumes fft_size == sample_length // 2 + 1 (the rfft
        # output length) so the assignment below fits — confirm.
        fft = np.fft.rfft(samples[i], norm="ortho")
        fft = np.real(fft) ** 2 + np.imag(fft) ** 2
        fft *= fft_norm
        ffts[i] = fft
    return frequencies, samples, ffts
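# 评论列表 (Comment list — web-page navigation residue, not code)
# 文章目录 (Table of contents — web-page navigation residue, not code)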