def textgenrnn_model(weights_path, num_classes, maxlen=40):
    """Build the textgenrnn architecture and load pretrained weights into it.

    The network is: sequence input of length ``maxlen`` -> 100-dim trainable
    embedding -> 128-unit LSTM (final output only, no per-step sequence) ->
    dense softmax over ``num_classes``.

    Args:
        weights_path: Path to a weights file accepted by
            ``Model.load_weights``.
        num_classes: Input vocabulary size of the embedding layer and the
            number of units in the softmax output layer (presumably the
            number of distinct tokens/characters -- confirm against caller).
        maxlen: Length of the input sequence the model accepts.

    Returns:
        The Keras ``Model`` with weights loaded, compiled with categorical
        cross-entropy loss and the Nadam optimizer.
    """
    # Local is named `input_layer` (not `input`) to avoid shadowing the
    # builtin; the layer's own name stays 'input' for by-name weight loading.
    input_layer = Input(shape=(maxlen,), name='input')
    embedded = Embedding(num_classes, 100, input_length=maxlen,
                         trainable=True, name='embedding')(input_layer)
    rnn = LSTM(128, return_sequences=False, name='rnn')(embedded)
    output = Dense(num_classes, name='output', activation='softmax')(rnn)
    model = Model(inputs=[input_layer], outputs=[output])
    # Weights are matched to layers by name rather than by position, so the
    # layer names above must agree with those stored in the weights file.
    model.load_weights(weights_path, by_name=True)
    model.compile(loss='categorical_crossentropy', optimizer='nadam')
    return model
评论列表
文章目录