baseline_gaussian.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:speech 作者: igul222 项目源码 文件源码
def sample_level_rnn(input_sequences, h0, reset):
    """
    Run a stack of N_GRUS sample-level GRUs over real-valued input sequences.

    input_sequences.shape: (batch size, seq len)
    h0.shape:              (batch size, N_GRUS, DIM)
    reset.shape:           ()
    output.shape:          (batch size, seq len, 2)
    last_hidden.shape:     (batch size, N_GRUS, DIM)

    If `reset` is true, the passed-in `h0` is ignored and the learned
    parameter 'SampleLevel.h0' (repeated for every batch element) is used as
    the initial state instead; otherwise `h0` is used unchanged.

    Returns the tuple (output, last_hidden). `last_hidden` stacks the
    final-timestep output of every GRU layer, in the same layout as `h0`.
    """

    # Learned initial state: one (N_GRUS, DIM) block, tiled over the batch.
    learned_h0 = lib.param(
        'SampleLevel.h0',
        numpy.zeros((N_GRUS, DIM), dtype=theano.config.floatX)
    )
    learned_h0 = T.alloc(learned_h0, h0.shape[0], N_GRUS, DIM)
    h0 = theano.ifelse.ifelse(reset, learned_h0, h0)

    # Real-valued inputs: every sample becomes a 'frame' of size 1 by adding
    # a trailing axis. No embedding and no rescaling is applied here; the
    # values are handed to the first GRU as given.
    frame_size = 1
    frames = input_sequences.reshape((
        input_sequences.shape[0],
        input_sequences.shape[1],
        1
    ))

    # Layer 0 reads the frames; every further layer reads the previous
    # layer's output. Layer i starts from its own slice h0[:, i].
    grus = [
        lib.ops.LowMemGRU('SampleLevel.GRU0', frame_size, DIM, frames, h0=h0[:, 0])
    ]
    # `range` (not `xrange`) so the loop runs on both Python 2 and Python 3.
    for i in range(1, N_GRUS):
        grus.append(
            lib.ops.LowMemGRU('SampleLevel.GRU'+str(i), DIM, DIM, grus[-1], h0=h0[:, i])
        )

    # Linear projection of the concatenated outputs of all layers; no
    # nonlinearity is applied in this function.
    # NOTE(review): 2 output units per timestep, presumably the two
    # parameters of a Gaussian (the file is baseline_gaussian.py) -- confirm
    # against the loss code.
    output = lib.ops.Linear(
        'Output',
        N_GRUS*DIM,
        2,
        T.concatenate(grus, axis=2)
    )

    # Final timestep of each layer, stacked along axis 1.
    last_hidden = T.stack([gru[:, -1] for gru in grus], axis=1)

    return (output, last_hidden)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号