reactionrnn.py source code

python

Project: reactionrnn Author: minimaxir
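from keras import backend as K
from keras.layers import Dense, Embedding, GRU, Input
from keras.models import Model


def reactionrnn_model(weights_path, num_classes, maxlen=140):
    '''
    Builds the model architecture for reactionrnn and
    loads the pretrained weights for the model.
    '''

    def normalized_relu(x):
        # ReLU, then scale each sample's outputs to sum to 1.
        # keepdims=True keeps the sum broadcastable for batches larger than one.
        x = K.relu(x)
        return x / K.sum(x, axis=-1, keepdims=True)

    text_input = Input(shape=(maxlen,), name='input')
    embedded = Embedding(num_classes, 100, input_length=maxlen,
                         name='embedding')(text_input)
    rnn = GRU(256, return_sequences=False, name='rnn')(embedded)
    output = Dense(5, name='output', activation=normalized_relu)(rnn)

    model = Model(inputs=[text_input], outputs=[output])
    # by_name=True matches weights to layers by their `name` arguments.
    model.load_weights(weights_path, by_name=True)
    model.compile(loss='mse', optimizer='nadam')
    return model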
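
The output layer's custom activation applies ReLU and then divides by the sum along the last axis, so each sample's five outputs become non-negative proportions. Below is a minimal framework-free sketch of that computation, assuming plain Python lists in place of Keras tensors. The name `normalized_relu` and the `eps` guard against an all-zero row are illustrative additions and are not part of the original code.

```python
def normalized_relu(rows, eps=1e-7):
    """Clip negatives to zero, then scale each row so it sums to about 1."""
    result = []
    for row in rows:
        clipped = [max(0.0, v) for v in row]  # ReLU on every element
        total = sum(clipped) + eps            # per-row sum; eps is an assumed guard against 0/0
        result.append([v / total for v in clipped])
    return result


# Two hypothetical pre-activation rows, each with five values like the Dense(5) layer.
batch = [[2.0, -1.0, 1.0, 0.0, 1.0],
         [0.5, 0.5, -3.0, 0.0, 0.0]]
proportions = normalized_relu(batch)
```

Each row is normalized by its own sum, which is why the tensor version needs a per-sample sum that still broadcasts against the `(batch, 5)` output.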