model.py source code

python

Project: headline-generation  Author: sallamander
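# Targets the Keras 1.x API: Embedding(dropout=...), GRU(dropout_W=..., dropout_U=...)
# and Model(input=..., output=...) were renamed or removed in Keras 2.
from keras.layers import Input, Embedding, GRU, Dense
from keras.models import Model


def make_model(embedding_weights, input_length=50):
    """Build a recurrent net from the input parameters and return it compiled.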

    Args: 
    ----
        embedding_weights: 2d np.ndarray
            Initial word-vector matrix of shape (vocabulary size, embedding dimension)
        input_length (optional): int
            Number of word indices in each article body (default 50)

    Returns:
    --------
        model: keras.models.Model (functional API), already compiled
    """

    dict_size = embedding_weights.shape[0]  # Vocabulary size (one row per word)
    embedding_dim = embedding_weights.shape[1]  # Length of each word vector

    bodies = Input(shape=(input_length,), dtype='int32') 
    embeddings = Embedding(input_dim=dict_size, output_dim=embedding_dim,
                           weights=[embedding_weights], dropout=0.5)(bodies)
    layer = GRU(1024, return_sequences=True, dropout_W=0.5, dropout_U=0.5)(embeddings)
    layer = GRU(1024, return_sequences=False, dropout_W=0.5, dropout_U=0.5)(layer)
    layer = Dense(dict_size, activation='softmax')(layer)

    model = Model(input=bodies, output=layer)

    model.compile(loss='categorical_crossentropy', optimizer='adagrad')

    return model
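The network takes each article body as a fixed-length row of `int32` word indices (matching `Input(shape=(input_length,), dtype='int32')`). Because it ends in a `dict_size`-way softmax trained with `categorical_crossentropy`, each target must be a one-hot row of width `dict_size`. Below is a minimal NumPy sketch of building such arrays. The helper name `make_training_arrays`, left-padding with index 0, and keeping the first `input_length` words are assumptions for illustration and are not taken from the headline-generation project.

```python
import numpy as np


def make_training_arrays(body_token_ids, target_ids, input_length, dict_size):
    """Hypothetical helper: shape raw examples for the make_model network.

    body_token_ids: list of lists of int word indices, one list per article body
    target_ids: list of int word indices, the word each example should predict
    Returns (X, y): X is int32 with shape (n, input_length);
    y is float32 one-hot with shape (n, dict_size).
    """
    n = len(body_token_ids)
    # Assumption: index 0 is reserved for padding, so short bodies are left-padded with 0.
    X = np.zeros((n, input_length), dtype="int32")
    for i, ids in enumerate(body_token_ids):
        ids = ids[:input_length]  # assumed policy: keep the first input_length words
        if ids:  # an empty body stays all padding
            X[i, -len(ids):] = ids
    # categorical_crossentropy compares the softmax output against one-hot rows.
    y = np.zeros((n, dict_size), dtype="float32")
    y[np.arange(n), target_ids] = 1.0
    return X, y


X, y = make_training_arrays([[5, 8, 2], list(range(1, 61))], [7, 3],
                            input_length=50, dict_size=100)
print(X.shape, y.shape)
```

The resulting `X` and `y` could then be passed to `model.fit(X, y)`. Every index in `X` and `y` must be smaller than `dict_size`, the number of rows in `embedding_weights`.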