def lstm_unroll(num_lstm_layer,
                num_hidden, dropout=0.,
                concat_decode=True, use_loss=False):
    """Build an unrolled stacked-LSTM symbol graph over MAX_LEN time steps.

    Parameters
    ----------
    num_lstm_layer : int
        Number of stacked LSTM layers.
    num_hidden : int
        Hidden-state size passed to each ``lstm`` cell.
    dropout : float
        Dropout ratio forwarded to each ``lstm`` cell (default 0.).
    concat_decode, use_loss : bool
        Accepted for interface compatibility but not used by this
        implementation.  NOTE(review): confirm against callers whether
        these were meant to select an alternative decode/loss path.

    Returns
    -------
    mx.sym.Symbol
        Group of MAX_LEN per-step softmax cross-entropy losses followed by
        the gradient-blocked final ``c`` and ``h`` states of every layer
        (named ``l%d_last_c`` / ``l%d_last_h``) so callers can carry state
        across unrolled segments.
    """
    # Classification (decode) parameters are pinned to their own ctx group
    # so they can be placed on a separate device under model parallelism.
    with mx.AttrScope(ctx_group='decode'):
        cls_weight = mx.sym.Variable("cls_weight")
        cls_bias = mx.sym.Variable("cls_bias")

    # Per-layer weights and initial (c, h) states, one ctx group per layer.
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        with mx.AttrScope(ctx_group='layer%d' % i):
            param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                         i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                         h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                         h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
            state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                              h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)

    # Slice the input sequence along the time axis, then run each layer
    # over all MAX_LEN steps before moving to the next layer (layer-major
    # unrolling); the previous layer's h outputs feed the next layer.
    hidden = mx.sym.SliceChannel(data=mx.sym.Variable("data"), num_outputs=MAX_LEN, squeeze_axis=0)
    for i in range(num_lstm_layer):
        next_hidden = []
        for t in range(MAX_LEN):
            with mx.AttrScope(ctx_group='layer%d' % i):
                # BUG FIX: original referenced the undefined name `n_hidden`
                # (the `num_hidden` parameter was never used) and hard-coded
                # dropout=0., silently ignoring the `dropout` argument.
                next_state = lstm(num_hidden, indata=hidden[t],
                                  prev_state=last_states[i],
                                  param=param_cells[i],
                                  layeridx=i, dropout=dropout)
                next_hidden.append(next_state.h)
                last_states[i] = next_state
        hidden = next_hidden[:]

    # Per-step classification loss; all steps share cls_weight/cls_bias.
    # NOTE(review): `n_classes` is a module-level name defined elsewhere
    # in this file — confirm it matches the label vocabulary size.
    sm = []
    labels = mx.sym.SliceChannel(data=mx.sym.Variable("labels"), num_outputs=MAX_LEN, squeeze_axis=0)
    for t in range(MAX_LEN):
        fc = mx.sym.FullyConnected(data=hidden[t],
                                   weight=cls_weight,
                                   bias=cls_bias,
                                   num_hidden=n_classes)
        sm.append(mx.sym.softmax_cross_entropy(fc, labels[t], name="sm"))

    # Block gradients through the final states: they are outputs for the
    # caller to feed into the next segment, not part of this graph's loss.
    for i in range(num_lstm_layer):
        state = last_states[i]
        state = LSTMState(c=mx.sym.BlockGrad(state.c, name="l%d_last_c" % i),
                          h=mx.sym.BlockGrad(state.h, name="l%d_last_h" % i))
        last_states[i] = state

    unpack_c = [state.c for state in last_states]
    unpack_h = [state.h for state in last_states]
    list_all = sm + unpack_c + unpack_h
    return mx.sym.Group(list_all)
# (blog-scrape navigation residue, kept as a comment so the file parses:
#  "评论列表" = comment list, "文章目录" = article table of contents)