ops.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:WaveNet 作者: ritheshkumar95 项目源码 文件源码
def conv1d(name,input,kernel,stride,n_filters,depth,bias=False,batchnorm=False,pad='valid',filter_dilation=(1,1),run_mode=0):
    W = lib.param(
        name+'.W',
        lasagne.init.HeNormal().sample((n_filters,depth,kernel,1)).astype('float32')
        )

    out = T.nnet.conv2d(input,W,subsample=(stride,1),border_mode=pad,filter_dilation=filter_dilation)

    if bias:
        b = lib.param(
            name + '.b',
            np.zeros(n_filters).astype('float32')
            )

        out += b[None,:,None,None]

    if batchnorm:
        out = BatchNorm(name,out,n_filters,mode=1,run_mode=run_mode)

    return out
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号