def fft_circ_conv1d(X, keys, batch_size, num_copies, num_keys=None, conj=False):
    """Circularly convolve ``X`` with ``keys`` via the FFT, averaging per group.

    ``X`` is tiled along axis 0 by ``kshp[0] // batch_size``, so that there is
    one input row per key row. Both tensors are split into complex form with
    ``HolographicMemory.split_to_complex``. The code then multiplies
    ``fft(x) * fft(k)`` element-wise and applies ``ifft``, which is circular
    convolution by the convolution theorem. The result is unsplit back to a
    real tensor. Consecutive groups of ``batch_iter`` rows are then
    mean-reduced to one row each.

    Example (from the original author's notes): with ``|X| = [2, 784]`` and
    ``|keys| = 3 * [2, 784]`` (``num_copies = 3``), X is tiled to
    ``[X; X; X]`` (6 / 2 = 3 copies). At test time, with ``|memories| =
    [3, 784]`` and ``|keys| = 3 * [n, 784]``, the tiled input has
    ``3n / 3 = n`` rows, where n is the number of desired parallel
    retrievals.

    Args:
        X: rank-2 tensor (``tf.tile`` is called with 2 multiples). It is
            tiled to match the number of keys.
        keys: tensor of keys; its axis-0 length is the number of keys.
        batch_size: divisor for the tiling factor; also the upper bound for
            the averaging group size ``batch_iter``.
        num_copies: only used in the diagnostic print.
        num_keys: if given, overrides the static axis-0 size of ``keys``.
            This is required when that size is unknown at graph-build time.
        conj: if True, ``keys`` are first passed through
            ``HolographicMemory.conj_real_by_complex`` and the product is
            unsplit with ``unsplit_from_complex_ir`` instead of
            ``unsplit_from_complex_ri``.

    Returns:
        Tensor ``C = [c0; c1; ...]`` with one entry along axis 0 for each
        full group of ``batch_iter`` rows of the convolution. A trailing
        partial group is dropped.

    Raises:
        ValueError: if the number of keys is unknown and ``num_keys`` is not
            given, or if there is not at least one full averaging group.
    """
    if conj:
        keys = HolographicMemory.conj_real_by_complex(keys)

    # Static shapes; fill in dimensions TF could not infer.
    xshp = X.get_shape().as_list()
    kshp = keys.get_shape().as_list()
    if num_keys is not None:
        kshp[0] = num_keys
    if kshp[1] is None:
        kshp[1] = xshp[1]
    if kshp[0] is None:
        raise ValueError('fft_circ_conv1d: the number of keys is not known '
                         'statically; pass num_keys explicitly')
    # Single-argument print(...) behaves identically on Python 2 and 3.
    print('X : %s | keys : %s | batch_size = %s' % (xshp, kshp, batch_size))

    # Duplicate the input by the ratio number_of_keys / batch_size. Floor
    # division keeps the tile multiple an int on Python 3 (tf.tile needs ints).
    num_dupes = kshp[0] // batch_size
    print('num dupes = %s' % num_dupes)
    tiled_x = tf.tile(X, [num_dupes, 1]) if num_dupes > 1 else X
    xcplx = HolographicMemory.split_to_complex(tiled_x)
    xshp = xcplx.get_shape().as_list()
    # NOTE(review): kshp is passed as a shape hint to split_to_complex here
    # but not for X above -- confirm against split_to_complex's signature.
    kcplx = HolographicMemory.split_to_complex(keys, kshp)

    # Convolve in the frequency domain and re-cast to a real-valued tensor.
    if conj:
        unsplit_func = HolographicMemory.unsplit_from_complex_ir
    else:
        unsplit_func = HolographicMemory.unsplit_from_complex_ri
    fft_mul = tf.mul(tf.fft(xcplx), tf.fft(kcplx))
    conv = unsplit_func(tf.ifft(fft_mul))
    print('full conv = %s' % conv.get_shape().as_list())

    batch_iter = min(batch_size, xshp[0]) if xshp[0] is not None else batch_size
    print('batch = %s | num_copies = %s | num_keys = %s | xshp[0] = %s | '
          'len(keys) = %s | batch iter = %s'
          % (batch_size, num_copies, num_keys, xshp[0], kshp[0], batch_iter))

    # Mean-reduce each full group of batch_iter rows. Without this check an
    # empty group list would surface as an IndexError / failed concat.
    if batch_iter < 1 or kshp[0] < batch_iter:
        raise ValueError('fft_circ_conv1d: need at least one full group of '
                         '%s rows to average, but there are %s key rows'
                         % (batch_iter, kshp[0]))
    num_groups = kshp[0] // batch_iter
    conv_concat = [
        tf.expand_dims(
            tf.reduce_mean(conv[g * batch_iter:(g + 1) * batch_iter], 0), 0)
        for g in range(num_groups)
    ]
    print('conv concat = %s x %s'
          % (len(conv_concat), conv_concat[0].get_shape().as_list()))

    # Return a single concatenated tensor: C = [c0; c1; ...]
    # (pre-1.0 TF signature: tf.concat(axis, values)).
    return tf.concat(0, conv_concat)
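# (scraped web-page residue, not code: "Comment list")
# (scraped web-page residue, not code: "Table of contents")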