run.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:returnn-benchmarks 作者: rwth-i6 项目源码 文件源码
def lstm_unroll(num_lstm_layer,
                num_hidden, dropout=0.,
                concat_decode=True, use_loss=False):
    """Build the symbol graph of a stacked LSTM unrolled over MAX_LEN steps.

    Relies on names defined elsewhere in the module: ``mx``, ``lstm``,
    ``LSTMParam``, ``LSTMState``, ``MAX_LEN`` and ``n_classes``.

    Args:
        num_lstm_layer: Number of stacked LSTM layers to create.
        num_hidden: Forwarded as the first positional argument of every
            ``lstm`` call (presumably the hidden size -- confirm against
            the ``lstm`` helper).
        dropout: Forwarded as the ``dropout`` argument of every ``lstm``
            call.
        concat_decode: Unused; kept so existing callers keep working.
        use_loss: Unused; kept so existing callers keep working.

    Returns:
        An ``mx.sym.Group`` containing, in this order: the MAX_LEN per-step
        ``softmax_cross_entropy`` symbols, the final ``c`` state of every
        layer, and the final ``h`` state of every layer (the states are
        wrapped in ``BlockGrad``).
    """
    # Classifier parameters are shared by all time steps and tagged with
    # their own context group.
    with mx.AttrScope(ctx_group='decode'):
        cls_weight = mx.sym.Variable("cls_weight")
        cls_bias = mx.sym.Variable("cls_bias")

    # One parameter set and one initial state per layer, each created inside
    # that layer's context group.
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        with mx.AttrScope(ctx_group='layer%d' % i):
            param_cells.append(LSTMParam(
                i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
            state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                              h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)

    # Stack the LSTM layers: each layer consumes the per-step outputs of the
    # layer below it (the sliced input data for the first layer).
    hidden = mx.sym.SliceChannel(data=mx.sym.Variable("data"),
                                 num_outputs=MAX_LEN, squeeze_axis=0)
    for i in range(num_lstm_layer):
        next_hidden = []
        # The scope does not depend on t, so it is entered once per layer.
        with mx.AttrScope(ctx_group='layer%d' % i):
            for t in range(MAX_LEN):
                # Use the caller-supplied num_hidden and dropout rather than
                # a module global / literal, so the arguments take effect.
                next_state = lstm(num_hidden, indata=hidden[t],
                                  prev_state=last_states[i],
                                  param=param_cells[i],
                                  layeridx=i, dropout=dropout)
                next_hidden.append(next_state.h)
                last_states[i] = next_state
        hidden = next_hidden

    # Per-step classification loss on top of the last layer's outputs.
    labels = mx.sym.SliceChannel(data=mx.sym.Variable("labels"),
                                 num_outputs=MAX_LEN, squeeze_axis=0)
    sm = []
    for t in range(MAX_LEN):
        fc = mx.sym.FullyConnected(data=hidden[t],
                                   weight=cls_weight,
                                   bias=cls_bias,
                                   num_hidden=n_classes)
        sm.append(mx.sym.softmax_cross_entropy(fc, labels[t], name="sm"))

    # Expose the final states as outputs, wrapped in BlockGrad.
    last_states = [
        LSTMState(c=mx.sym.BlockGrad(state.c, name="l%d_last_c" % i),
                  h=mx.sym.BlockGrad(state.h, name="l%d_last_h" % i))
        for i, state in enumerate(last_states)
    ]

    unpack_c = [state.c for state in last_states]
    unpack_h = [state.h for state in last_states]
    return mx.sym.Group(sm + unpack_c + unpack_h)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号