conv1d.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:DBQA-KBQA 作者: Lucien-qiang 项目源码 文件源码
def convolve1d_4D_scan(input, W, mode='full'):
  """Symbolic 1D convolution along the word axis, implemented with theano.scan.

  A window of ``filter_width`` rows is slid along the ``nwords`` axis of
  ``input``; at every position the (flipped) filter is multiplied
  elementwise with the window and summed over the window rows. The result
  is finally summed over the input-channel axis.

  Args:
    input: 4D symbolic tensor unpacked as
      (batch_size, nchannels, nwords, ndim).
    W: 4D symbolic filter tensor unpacked as
      (nkernels_out, nkernels_in, filter_width, ndim).
    mode: 'full' zero-pads ``filter_width - 1`` rows on both ends of the
      word axis, yielding ``nwords + filter_width - 1`` output positions;
      'valid' uses no padding, yielding ``nwords - filter_width + 1``.

  Returns:
    Symbolic tensor of shape (batch_size, nkernels_out, out_len, ndim),
    where ``out_len`` is the number of output positions for ``mode``.

  Raises:
    ValueError: if ``mode`` is neither 'full' nor 'valid'.

  NOTE(review): the elementwise product and the final reshape only line up
  when ``nchannels == nkernels_in`` and when the last axis of ``input``
  matches the last axis of ``W`` -- neither is checked here.
  """
  # Fail fast: previously an unknown mode left conv_out/new_shape unbound and
  # surfaced as an UnboundLocalError at the very end of the function.
  if mode not in ('full', 'valid'):
    raise ValueError(
        "mode must be 'full' or 'valid', got %r" % (mode,))

  batch_size, nchannels, nwords, ndim = input.shape
  # ndim is rebound here; from now on it refers to W's last axis.
  nkernels_out, nkernels_in, filter_width, ndim = W.shape

  # Unroll the filter: (out, in, width, ndim) -> (out, width, in*ndim).
  W_unrolled = W.dimshuffle(0, 2, 1, 3).flatten(ndim=3)
  # Replicate the filter 'batch_size' times and squash everything but the
  # width axis into columns: -> (width, out*batch*in*ndim).
  # T.alloc is used for the replication because, per the original author's
  # note, the equivalent T.tile formulation did not give a gradient.
  W_tiled = T.alloc(
      W_unrolled,
      batch_size,
      W_unrolled.shape[0],
      W_unrolled.shape[1],
      W_unrolled.shape[2],
  ).dimshuffle(1, 2, 0, 3).flatten(ndim=3).dimshuffle(1, 0, 2).flatten(ndim=2)
  # Flip the kernel along the width axis (convolution, not correlation).
  W_tiled = W_tiled[::-1]

  # Unroll the input: (batch, ch, words, ndim) -> (words, batch*ch*ndim).
  input_reshaped = (input.dimshuffle(0, 2, 1, 3).flatten(ndim=3)
                    .dimshuffle(1, 0, 2).flatten(ndim=2))
  # Replicate once per output kernel so the columns line up with W_tiled:
  # -> (words, out*batch*ch*ndim).
  input_tiled = T.alloc(
      input_reshaped,
      nkernels_out,
      input_reshaped.shape[0],
      input_reshaped.shape[1],
  ).dimshuffle(1, 0, 2).flatten(ndim=2)

  if mode == 'full':
    # Zero rows on both ends so the filter can overhang the sequence.
    pad = T.zeros((filter_width - 1,
                   nkernels_out * batch_size * nchannels * ndim))
    signal = T.concatenate([pad, input_tiled, pad])
    n_steps = nwords + filter_width - 1
  else:  # mode == 'valid'
    signal = input_tiled
    n_steps = nwords - filter_width + 1

  # One scan step per output position: multiply the window by the flipped
  # filter and collapse the window rows.
  conv_out, _ = theano.scan(
      fn=lambda i: (W_tiled * signal[i:i + filter_width]).sum(axis=0),
      outputs_info=None,
      sequences=[T.arange(0, n_steps)])

  # (steps, out*batch*in*ndim) -> (steps, out, batch, in, ndim)
  # -> (batch, out, steps, in, ndim) -> sum over input channels.
  new_shape = (n_steps, nkernels_out, batch_size, nkernels_in, ndim)
  conv_reshaped = conv_out.reshape(new_shape).dimshuffle(2, 1, 0, 3, 4).sum(axis=3)
  return conv_reshaped
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号