import theano
import theano.tensor as T


def convolve1d_4D_scan(input, W, mode='full'):
    batch_size, nchannels, nwords, ndim = input.shape
    nkernels_out, nkernels_in, filter_width, ndim = W.shape

    # Unroll the filter along the word axis:
    # (nkernels_out, nkernels_in, filter_width, ndim) -> (nkernels_out, filter_width, nkernels_in*ndim)
    W_unrolled = W.dimshuffle(0, 2, 1, 3).flatten(ndim=3)

    # Replicate the filters batch_size times and squash the output filters along
    # the column axis, giving shape (filter_width, nkernels_out*batch_size*nkernels_in*ndim).
    # Note: T.tile(W_unrolled, (1, 1, batch_size)) would be simpler, but it
    # doesn't give a gradient, so we broadcast with T.alloc instead.
    W_tiled = T.alloc(W_unrolled, batch_size,
                      W_unrolled.shape[0], W_unrolled.shape[1], W_unrolled.shape[2])
    W_tiled = W_tiled.dimshuffle(1, 2, 0, 3).flatten(ndim=3).dimshuffle(1, 0, 2).flatten(ndim=2)
    W_tiled = W_tiled[::-1]  # flip the kernel along the filter_width axis

    # Unroll the input the same way, to (nwords, batch_size*nchannels*ndim),
    # then replicate it nkernels_out times to line up with the tiled filters
    # (nchannels must equal nkernels_in). Again, T.alloc instead of T.tile.
    input_reshaped = input.dimshuffle(0, 2, 1, 3).flatten(ndim=3).dimshuffle(1, 0, 2).flatten(ndim=2)
    input_tiled = T.alloc(input_reshaped, nkernels_out,
                          input_reshaped.shape[0], input_reshaped.shape[1])
    input_tiled = input_tiled.dimshuffle(1, 0, 2).flatten(ndim=2)

    if mode == 'full':
        # Zero-pad with filter_width-1 rows on each side so every overlap is produced.
        pad = T.zeros((filter_width - 1, nkernels_out * batch_size * nchannels * ndim))
        input_padded = T.concatenate([pad, input_tiled, pad])
        conv_out, _ = theano.scan(
            fn=lambda i: (W_tiled * input_padded[i:i + filter_width]).sum(axis=0),
            outputs_info=None,
            sequences=[T.arange(0, nwords + filter_width - 1)])
        new_shape = (nwords + filter_width - 1, nkernels_out, batch_size, nkernels_in, ndim)
    elif mode == 'valid':
        conv_out, _ = theano.scan(
            fn=lambda i: (W_tiled * input_tiled[i:i + filter_width]).sum(axis=0),
            outputs_info=None,
            sequences=[T.arange(0, nwords - filter_width + 1)])
        new_shape = (nwords - filter_width + 1, nkernels_out, batch_size, nkernels_in, ndim)
    else:
        raise ValueError("mode must be 'full' or 'valid'")

    # Reshape to (batch_size, nkernels_out, nwords_out, nkernels_in, ndim)
    # and sum over the input channels.
    conv_reshaped = conv_out.reshape(new_shape).dimshuffle(2, 1, 0, 3, 4).sum(axis=3)
    return conv_reshaped
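
A minimal usage sketch (my own illustration, not part of the original code; the shapes
and variable names are assumptions): compile the op with Theano and run it on random
data to check the output shape in 'full' mode.

import numpy as np
import theano
import theano.tensor as T

x = T.tensor4('x')   # (batch_size, nchannels, nwords, ndim)
w = T.tensor4('w')   # (nkernels_out, nkernels_in, filter_width, ndim)
y = convolve1d_4D_scan(x, w, mode='full')
f = theano.function([x, w], y)

x_val = np.random.randn(3, 2, 7, 5).astype(theano.config.floatX)
w_val = np.random.randn(4, 2, 3, 5).astype(theano.config.floatX)
out = f(x_val, w_val)
print(out.shape)  # (3, 4, 9, 5): (batch, nkernels_out, nwords + filter_width - 1, ndim)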