textgenrnn.py source code

python

Project: textgenrnn  Author: minimaxir
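from keras.layers import Input, Embedding, LSTM, Dense
from keras.models import Model


def textgenrnn_model(weights_path, num_classes, maxlen=40):
    '''
    Builds the model architecture for textgenrnn and
    loads the pretrained weights for the model.
    '''

    # Integer token ids, one sequence of length maxlen per sample.
    inputs = Input(shape=(maxlen,), name='input')
    # Map each token id to a trainable 100-dimensional vector.
    embedded = Embedding(num_classes, 100, input_length=maxlen,
                         trainable=True, name='embedding')(inputs)
    # Keep only the final LSTM output (return_sequences=False).
    rnn = LSTM(128, return_sequences=False, name='rnn')(embedded)
    # Probability distribution over the next token.
    output = Dense(num_classes, name='output', activation='softmax')(rnn)

    model = Model(inputs=[inputs], outputs=[output])
    # by_name=True loads weights only into layers whose names match.
    model.load_weights(weights_path, by_name=True)
    model.compile(loss='categorical_crossentropy', optimizer='nadam')
    return model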
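
The model above expects each sample to be a sequence of exactly maxlen integer ids and, because it is compiled with categorical_crossentropy, a one-hot target of length num_classes. The following is a minimal sketch of how such inputs might be prepared, using only the standard library. The helper names (encode_text, one_hot), the toy vocabulary, the use of id 0 for padding, and left-padding are all assumptions made for illustration; the snippet above does not show how textgenrnn itself builds its vocabulary or pads its sequences.

```python
def encode_text(text, vocab, maxlen=40):
    """Map characters to integer ids and left-pad with 0 to exactly maxlen ids."""
    # Characters missing from the vocabulary are skipped (an assumption).
    ids = [vocab[ch] for ch in text if ch in vocab]
    # Keep only the most recent maxlen characters.
    ids = ids[-maxlen:]
    return [0] * (maxlen - len(ids)) + ids


def one_hot(index, num_classes):
    """Return a list of num_classes floats with a single 1.0 at position index."""
    row = [0.0] * num_classes
    row[index] = 1.0
    return row


# Hypothetical toy vocabulary; id 0 is reserved for padding.
chars = sorted(set("hello world"))
vocab = {ch: i + 1 for i, ch in enumerate(chars)}
num_classes = len(vocab) + 1

# Context "hello worl" with next character "d" as the training target.
x = encode_text("hello worl", vocab, maxlen=40)
y = one_hot(vocab["d"], num_classes)
```

Here x has the length the Input layer expects (maxlen) and y has the length of the softmax output (num_classes), so a batch of such pairs, converted to arrays, would match the shapes the model is built for.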