def create_network():
    """Build an FFT-domain masking regression network.

    The real-valued input chunk is transformed with an FFT. A dense / LSTM /
    dense stack, fed with the stacked real and imaginary parts of the
    spectrum, predicts a sigmoid mask with one value per FFT bin. The mask
    scales both the real and imaginary parts of the original spectrum, and an
    inverse FFT (real part) brings the result back to the time domain.

    Relies on the module-level names ``chunk_size`` (length of each input
    chunk) and ``learning_rate`` (Adam step size), and on ``tf`` / ``tflearn``
    being imported at module level.

    Returns:
        The tflearn regression layer (optimizer "adam", loss "mean_square").

    Raises:
        ValueError: if the FFT has fewer than 4 bins, which would give the
            LSTM bottleneck (``fft_size // 4`` units) zero width.
    """
    # NOTE: this preprocessing object is built but deliberately NOT attached
    # to input_data (pass ``data_preprocessing=dp`` there to enable it).
    # Its construction is kept because it registers its own variables in
    # the graph, which saved checkpoints may contain.
    dp = tflearn.data_preprocessing.DataPreprocessing()
    dp.add_featurewise_zero_center()
    dp.add_featurewise_stdnorm()

    network = tflearn.input_data(shape=[None, chunk_size])

    # The input is a real signal: pair it with an explicit zero imaginary part
    # of the same shape/dtype so tf.complex gets matching operands.
    input_complex = tf.complex(network, tf.zeros_like(network))
    input_orig_fft = tf.fft(input_complex)
    # The network sees real and imaginary parts stacked on a trailing axis.
    input_fft = tf.stack([tf.real(input_orig_fft), tf.imag(input_orig_fft)], axis=2)
    fft_size = int(input_fft.shape[1])
    print("fft shape: " + str(input_fft.get_shape()))

    omg = fft_size
    # Integer division: layer widths and reshape dims must be ints
    # (``omg / 2`` is a float under Python 3 and breaks graph construction).
    half = omg // 2
    quarter = omg // 4
    if quarter < 1:
        raise ValueError(
            "create_network needs at least 4 FFT bins, got %d" % fft_size
        )
    nn_reg = None  # no weight regularization on the dense layers

    # Encoder: dense + batch-norm layers narrowing from 2*omg down to omg//2.
    mask = input_fft
    for units in (omg * 2, omg, half):
        mask = tflearn.layers.fully_connected(mask, units, activation="tanh", regularizer=nn_reg)
        mask = tflearn.layers.normalization.batch_normalization(mask)

    # Bottleneck: present the encoding as a length-1 sequence to the LSTM.
    mask = tflearn.reshape(mask, [-1, 1, half])
    mask = tflearn.layers.recurrent.lstm(mask, quarter)

    # Decoder: mirror of the encoder, widening back up to 2*omg.
    for units in (half, omg, omg * 2):
        mask = tflearn.layers.fully_connected(mask, units, activation="tanh", regularizer=nn_reg)
        mask = tflearn.layers.normalization.batch_normalization(mask)

    # One mask value in (0, 1) per FFT bin.
    mask = tflearn.layers.fully_connected(mask, omg, activation="sigmoid", regularizer=nn_reg)

    # Apply the same mask to the real and imaginary parts, then invert the FFT.
    real = tf.multiply(tf.real(input_orig_fft), mask)
    imag = tf.multiply(tf.imag(input_orig_fft), mask)
    network = tf.real(tf.ifft(tf.complex(real, imag)))
    print("final shape: " + str(network.get_shape()))
    return tflearn.regression(network, optimizer="adam", learning_rate=learning_rate, loss="mean_square")
# (Scraped web-page boilerplate, "comment list" / "table of contents";
# not part of the program.)