def _build_sequence_model(self, sequence_input):
RNN = GRU if self._rnn_type == 'gru' else LSTM
def rnn():
rnn = RNN(units=self._rnn_output_size,
return_sequences=True,
dropout=self._dropout_rate,
recurrent_dropout=self._dropout_rate,
kernel_regularizer=self._regularizer,
kernel_initializer=self._initializer,
implementation=2)
rnn = Bidirectional(rnn) if self._bidirectional_rnn else rnn
return rnn
input_ = sequence_input
for _ in range(self._rnn_layers):
input_ = BatchNormalization(axis=-1)(input_)
rnn_out = rnn()(input_)
input_ = rnn_out
time_dist_dense = TimeDistributed(Dense(units=self._vocab_size))(rnn_out)
return time_dist_dense
评论列表
文章目录