conv1d.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:DEEP-CLICK-MODEL 作者: THUIR 项目源码 文件源码
def convolve1d_4D(input, W, mode='full'):
    """Convolve a 4D input with a 4D filter bank along the word axis.

    The 4D problem is flattened into a single 2D call to ``convolve1d_2D``:
    both operands are unrolled so that every (output kernel, batch item,
    input kernel, embedding dim) combination becomes one column, the columns
    are convolved along axis 0, and the result is folded back and summed
    over the input-kernel axis.

    Args:
        input: 4D symbolic tensor unpacked as
            ``(batch_size, nchannels, nwords, ndim)``.
        W: 4D symbolic filter tensor unpacked as
            ``(nkernels_out, nkernels_in, filter_width, ndim)``.
        mode: ``'full'`` (output length ``nwords + filter_width - 1``) or
            ``'valid'`` (output length ``nwords - filter_width + 1``).

    Returns:
        Symbolic tensor of shape ``(batch_size, nkernels_out, out_len, ndim)``.

    Raises:
        ValueError: if ``mode`` is not ``'full'`` or ``'valid'``.

    NOTE(review): ``convolve1d_2D`` is defined elsewhere; it is assumed to
    convolve matching columns of its two 2D arguments along axis 0. The
    final reshape also presumes the input's channel count equals
    ``nkernels_in`` and both trailing dims agree -- confirm against callers.
    """
    # Fail early with a clear message; otherwise an unknown mode would only
    # surface as an unbound output shape after the graph was already built.
    if mode not in ('full', 'valid'):
        raise ValueError(
            "mode must be 'full' or 'valid', got %r" % (mode,))

    # Only batch_size and nwords are needed from the input; the trailing
    # dimension used below is the one reported by W, as before.
    batch_size, _nchannels, nwords, _ = input.shape
    nkernels_out, nkernels_in, filter_width, ndim = W.shape

    # Unroll filter along columns:
    # (nkernels_out, filter_width, nkernels_in * ndim)
    W_unrolled = W.dimshuffle(0, 2, 1, 3).flatten(ndim=3)

    # Replicate the filters 'batch_size' times and squash everything except
    # filter_width into the column axis. T.alloc is used instead of T.tile
    # because the T.tile formulation did not provide a gradient.
    W_replicated = T.alloc(
        W_unrolled,
        batch_size,
        W_unrolled.shape[0],
        W_unrolled.shape[1],
        W_unrolled.shape[2],
    )
    # -> (nkernels_out, filter_width, batch_size * nkernels_in * ndim)
    W_by_kernel = W_replicated.dimshuffle(1, 2, 0, 3).flatten(ndim=3)
    # -> (filter_width, nkernels_out * batch_size * nkernels_in * ndim)
    W_tiled = W_by_kernel.dimshuffle(1, 0, 2).flatten(ndim=2)

    # Unroll the input the same way: (nwords, batch_size * nchannels * ndim)
    input_unrolled = input.dimshuffle(0, 2, 1, 3).flatten(ndim=3)
    input_reshaped = input_unrolled.dimshuffle(1, 0, 2).flatten(ndim=2)

    # Replicate the input once per output kernel so its columns line up
    # with W_tiled (again via T.alloc rather than T.tile).
    input_replicated = T.alloc(
        input_reshaped,
        nkernels_out,
        input_reshaped.shape[0],
        input_reshaped.shape[1],
    )
    # -> (nwords, nkernels_out * batch_size * nchannels * ndim)
    input_tiled = input_replicated.dimshuffle(1, 0, 2).flatten(ndim=2)

    conv_res = convolve1d_2D(input_tiled, W_tiled, mode=mode)

    if mode == 'full':
        out_len = nwords + filter_width - 1
    else:  # 'valid'
        out_len = nwords - filter_width + 1
    new_shape = (out_len, nkernels_out, batch_size, nkernels_in, ndim)

    # Fold the columns back, move batch to the front, then sum over the
    # input-kernel axis to combine the contributions of all input kernels.
    conv_out = conv_res.reshape(new_shape).dimshuffle(2, 1, 0, 3, 4).sum(axis=3)
    return conv_out

##########################################
### Using einsum for 4d matrices
##########################################
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号