def convolve1d_4D(input, W, mode='full'):
    """1D convolution along the word axis of a 4D symbolic tensor.

    Implemented by unrolling/tiling both operands into 2D matrices,
    delegating to ``convolve1d_2D``, and reshaping the result back.

    Parameters
    ----------
    input : symbolic 4D tensor
        Shape ``(batch_size, nchannels, nwords, ndim)``.
    W : symbolic 4D tensor
        Filters of shape ``(nkernels_out, nkernels_in, filter_width, ndim)``.
        NOTE(review): assumes ``nkernels_in == nchannels`` and that W's
        trailing axis equals the input's ``ndim`` -- confirm against callers.
    mode : {'full', 'valid'}
        Convolution mode, forwarded to ``convolve1d_2D``.

    Returns
    -------
    Symbolic 4D tensor of shape
    ``(batch_size, nkernels_out, output_width, ndim)`` where
    ``output_width`` is ``nwords + filter_width - 1`` for 'full' and
    ``nwords - filter_width + 1`` for 'valid'.

    Raises
    ------
    ValueError
        If ``mode`` is neither 'full' nor 'valid' (previously this fell
        through and crashed later with an opaque NameError on new_shape).
    """
    batch_size, nchannels, nwords, ndim = input.shape
    # Unpack W's trailing axis under a separate name so it does not silently
    # shadow the input's `ndim` (the two are assumed equal, see docstring).
    nkernels_out, nkernels_in, filter_width, _w_ndim = W.shape
    # Unroll filter along columns.
    W_unrolled = W.dimshuffle(0, 2, 1, 3).flatten(ndim=3)
    # Replicate input filters 'batch_size' times and squash out_filters along
    # the column axis. T.tile would be the natural op here, but it does not
    # provide a gradient -- hence the alloc/dimshuffle/flatten dance.
    W_tiled = T.alloc(W_unrolled,
                      batch_size,
                      W_unrolled.shape[0],
                      W_unrolled.shape[1],
                      W_unrolled.shape[2]
                      ).dimshuffle(1, 2, 0, 3).flatten(ndim=3).dimshuffle(1, 0, 2).flatten(ndim=2)
    # Unroll input and pad to fit the output filters.
    input_reshaped = input.dimshuffle(0, 2, 1, 3).flatten(ndim=3).dimshuffle(1, 0, 2).flatten(ndim=2)
    # Same alloc trick as above: T.tile(input_reshaped, (1, nkernels_out))
    # would be equivalent but has no gradient.
    input_tiled = T.alloc(input_reshaped,
                          nkernels_out,
                          input_reshaped.shape[0],
                          input_reshaped.shape[1]
                          ).dimshuffle(1, 0, 2).flatten(ndim=2)
    conv_res = convolve1d_2D(input_tiled, W_tiled, mode=mode)
    if mode == 'full':
        new_shape = (nwords + filter_width - 1, nkernels_out, batch_size, nkernels_in, ndim)
    elif mode == 'valid':
        new_shape = (nwords - filter_width + 1, nkernels_out, batch_size, nkernels_in, ndim)
    else:
        raise ValueError("mode must be 'full' or 'valid', got %r" % (mode,))
    # Fold the (batch, out_kernel, width, in_kernel, dim) layout back to 4D,
    # summing contributions over the input-kernel axis.
    conv_out = conv_res.reshape(new_shape).dimshuffle(2, 1, 0, 3, 4).sum(axis=3)
    return conv_out
##########################################
### Using einsum for 4d matrices
##########################################
# 评论列表 ("comment list") / 文章目录 ("article table of contents"):
# leftover blog-navigation text from the page this code was scraped from;
# commented out because as bare names they would raise NameError on import.