def reactionrnn_model(weights_path, num_classes, maxlen=140):
    '''
    Builds the model architecture for reactionrnn and
    loads the pretrained weights for the model.

    Args:
        weights_path: Path of the weights file handed to
            `Model.load_weights`; weights are matched to layers by name.
        num_classes: Input dimension (vocabulary size) of the
            embedding layer.
        maxlen: Length of each input sequence.

    Returns:
        The compiled model (loss 'mse', optimizer 'nadam'). It maps an
        input of shape `(maxlen,)` to 5 non-negative outputs that are
        rescaled to sum to 1 for each sample.
    '''

    def normalized_relu(x):
        # Clip negatives, then rescale each sample so its outputs sum to 1.
        # keepdims=True keeps the per-sample sums shaped (batch, 1) so the
        # division broadcasts row by row. Without it the (batch,) sums are
        # aligned against the last (output) axis, which is only correct
        # when the batch holds a single sample.
        # NOTE(review): if every pre-activation of a sample is <= 0 the
        # sum is 0 and the division is 0/0 -- confirm whether this can
        # happen with the pretrained weights.
        positive = K.relu(x)
        return positive / K.sum(positive, axis=-1, keepdims=True)

    # Named `model_input` rather than `input` to avoid shadowing the
    # builtin; the layer name stays 'input' so weights load by name.
    model_input = Input(shape=(maxlen,), name='input')
    embedded = Embedding(num_classes, 100, input_length=maxlen,
                         name='embedding')(model_input)
    # return_sequences=False: only the final GRU state feeds the output.
    rnn = GRU(256, return_sequences=False, name='rnn')(embedded)
    output = Dense(5, name='output', activation=normalized_relu)(rnn)

    model = Model(inputs=[model_input], outputs=[output])
    # by_name=True loads weights into layers whose names match those
    # stored in the weights file.
    model.load_weights(weights_path, by_name=True)
    model.compile(loss='mse', optimizer='nadam')
    return model
# Comment list
# Table of contents