# Note: this uses the Keras 1.x API (W_regularizer/dropout_W, consume_less,
# Model(input=..., output=...)); under Keras 2 these arguments were renamed.
# decode, decode_output_shape and ctc_lambda_func are defined elsewhere in
# the module (see the sketch after the function).
from keras.layers import (Input, Dropout, Bidirectional, LSTM,
                          TimeDistributed, Dense, Lambda)
from keras.models import Model
from keras.regularizers import l2


def sbrt2017(num_hiddens, var_dropout, dropout, weight_decay, num_features=39,
             num_classes=28):
    """SBRT model.

    References:
        [1] Gal, Yarin, and Zoubin Ghahramani. "A Theoretically Grounded
            Application of Dropout in Recurrent Neural Networks", 2015.
        [2] Graves, Alex, Abdel-rahman Mohamed, and Geoffrey Hinton. "Speech
            Recognition with Deep Recurrent Neural Networks", 2013.
        [3] Wu, Yonghui, et al. "Google's Neural Machine Translation System:
            Bridging the Gap between Human and Machine Translation", 2016.
    """
    # Acoustic features: (batch, time, num_features)
    x = Input(name='inputs', shape=(None, num_features))
    o = x
    if dropout > 0.0:
        o = Dropout(dropout)(o)
    # Bidirectional LSTM with L2 weight decay on the input and recurrent
    # weights, and variational dropout [1] on both connections
    o = Bidirectional(LSTM(num_hiddens,
                           return_sequences=True,
                           W_regularizer=l2(weight_decay),
                           U_regularizer=l2(weight_decay),
                           dropout_W=var_dropout,
                           dropout_U=var_dropout,
                           consume_less='gpu'))(o)
    if dropout > 0.0:
        o = Dropout(dropout)(o)
    # Per-timestep linear projection onto the output alphabet (CTC logits)
    o = TimeDistributed(Dense(num_classes,
                              W_regularizer=l2(weight_decay)))(o)

    # Placeholders for the ground-truth transcriptions and sequence lengths
    labels = Input(name='labels', shape=(None,), dtype='int32', sparse=True)
    inputs_length = Input(name='inputs_length', shape=(None,), dtype='int32')

    # Greedy CTC decoder over the network outputs
    dec = Lambda(decode, output_shape=decode_output_shape,
                 arguments={'is_greedy': True}, name='decoder')
    y_pred = dec([o, inputs_length])

    # CTC loss computed inside the graph as a layer
    ctc = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')
    loss = ctc([o, labels, inputs_length])

    return Model(input=[x, labels, inputs_length], output=[loss, y_pred])
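The decode, decode_output_shape and ctc_lambda_func helpers referenced above live elsewhere in the original module and are not shown. A minimal sketch of what they might look like on the TensorFlow 1.x backend follows; the names, shapes, and greedy/beam behaviour here are assumptions, not the source's actual implementation.

import tensorflow as tf


def ctc_lambda_func(args):
    # Sketch: per-sample CTC loss via tf.nn.ctc_loss; the labels arrive as
    # a SparseTensor because the 'labels' Input was declared sparse=True.
    y_pred, labels, inputs_length = args
    # tf.nn.ctc_loss expects time-major logits: (time, batch, num_classes)
    return tf.nn.ctc_loss(labels,
                          tf.transpose(y_pred, (1, 0, 2)),
                          tf.reshape(inputs_length, [-1]))


def decode(args, is_greedy=True, beam_width=100):
    # Sketch: CTC decoding of the logits, greedy by default.
    y_pred, inputs_length = args
    logits = tf.transpose(y_pred, (1, 0, 2))
    seq_len = tf.reshape(inputs_length, [-1])
    if is_greedy:
        decoded, _ = tf.nn.ctc_greedy_decoder(logits, seq_len)
    else:
        decoded, _ = tf.nn.ctc_beam_search_decoder(logits, seq_len,
                                                   beam_width=beam_width)
    return decoded[0]  # SparseTensor of decoded label indices


def decode_output_shape(input_shapes):
    # Sketch: (batch, max_decoded_length), both unknown at graph build time.
    return (None, None)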
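Since the CTC loss is already computed in-graph by the 'ctc' Lambda layer, training typically compiles the model with a pass-through objective on that output and ignores the decoder output. A hypothetical usage sketch (all hyper-parameter values are placeholders):

import keras.backend as K

model = sbrt2017(num_hiddens=256, var_dropout=0.2, dropout=0.2,
                 weight_decay=1e-4)
# 'ctc' already holds the per-sample loss, so its objective just passes the
# prediction through; 'decoder' is only used at inference, so its loss is
# a dummy that gets a zero weight.
model.compile(optimizer='adam',
              loss={'ctc': lambda y_true, y_pred: y_pred,
                    'decoder': lambda y_true, y_pred: K.zeros((1,))},
              loss_weights={'ctc': 1.0, 'decoder': 0.0})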