ops.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:mimicry.ai 作者: fizerkhan 项目源码 文件源码
def conv1d(
    name,
    input,
    input_dim,
    output_dim,
    filter_size,
    init = 'glorot',
    non_linearity = 'relu',
    bias = True
    ):
    """Build a 1-D convolution layer (Theano graph) followed by an activation.

    The 1-D convolution is expressed as a ``T.nnet.conv2d`` over a trailing
    singleton width axis, with ``border_mode='valid'`` and
    ``filter_flip=False`` (i.e. cross-correlation, no kernel flip).

    Args:
        name: Prefix for the trainable parameters. The kernel is registered
            through ``lib.param`` as ``name + ".W"``; when ``bias`` is true,
            a bias is registered as ``name + ".b"``.
        input: Symbolic 3-D tensor. It is dimshuffled with ``(0, 2, 1, 'x')``,
            so it is presumably laid out as (batch, time, input_dim) --
            TODO confirm against callers.
        input_dim: Number of input channels (second axis of the kernel).
        output_dim: Number of requested output channels. For ``'gated'``,
            ``2 * output_dim`` filters are allocated and passed to
            ``gated_non_linerity``, which is assumed to reduce them back to
            ``output_dim`` -- TODO confirm.
        filter_size: Kernel length along the time axis.
        init: Kernel initializer: ``'glorot'`` (Lasagne ``GlorotUniform``)
            or ``'he'`` (Lasagne ``HeUniform``).
        non_linearity: One of ``'gated'``, ``'relu'``, ``'elu'`` or
            ``'identity'``.
        bias: If true, add a zero-initialised per-filter bias before the
            activation.

    Returns:
        Symbolic 3-D tensor with axes (batch, time', channels), where
        ``time' = time - filter_size + 1`` because of the ``'valid'`` border.

    Raises:
        NotImplementedError: If ``init`` or ``non_linearity`` is not one of
            the supported values. Raised before any parameter is registered.
    """
    # Validate string options first so a bad argument fails before any
    # side effect (lib.param registration). Previously an unknown `init`
    # surfaced as an UnboundLocalError, and an unknown `non_linearity` was
    # rejected only after W/b had been registered.
    if init not in ('glorot', 'he'):
        raise NotImplementedError("{} initialization not implemented!".format(init))
    if non_linearity not in ('gated', 'relu', 'elu', 'identity'):
        raise NotImplementedError("{} non-linearity not implemented!".format(non_linearity))

    import lasagne

    # Move the feature axis to position 1 and append a broadcastable width
    # axis so the 1-D convolution can run through conv2d.
    inp = input.dimshuffle(0, 2, 1, 'x')

    if init == 'glorot':
        initializer = lasagne.init.GlorotUniform()
    else:  # 'he' (validated above)
        initializer = lasagne.init.HeUniform()

    # The gated activation consumes two filter banks per output channel.
    num_filters = 2 * output_dim if non_linearity == 'gated' else output_dim

    W_shape = (num_filters, input_dim, filter_size, 1)
    W = lib.param(name + ".W", initializer.sample(W_shape))

    if bias:
        b = lib.param(name + ".b", lasagne.init.Constant(0.).sample((num_filters,)))

    conv_out = T.nnet.conv2d(
        inp, W,
        filter_flip=False,
        border_mode='valid'
    )

    if bias:
        # Broadcast the per-filter bias over batch, time and width axes.
        conv_out = conv_out + b[None, :, None, None]

    if non_linearity == 'gated':
        # NOTE(review): assumes gated_non_linerity splits channels on axis 1.
        activation = gated_non_linerity
    elif non_linearity == 'relu':
        activation = T.nnet.relu
    elif non_linearity == 'elu':
        def activation(x):
            return T.switch(x >= 0., x, T.exp(x) - floatX(1.))
    else:  # 'identity' (validated above)
        def activation(x):
            return x

    # Bug fix: the activation used to be selected but never applied
    # (`output = conv_out`). Every non_linearity therefore acted as
    # 'identity', and 'gated' returned 2*output_dim raw channels.
    output = activation(conv_out)

    # Drop the singleton width axis, then reorder
    # (batch, channels, time') -> (batch, time', channels).
    output = output.reshape((output.shape[0], output.shape[1], output.shape[2]))
    return output.dimshuffle(0, 2, 1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号