model.py source file

python

Project: baseline  Author: dpressel
from keras.layers import (Input, Embedding, Conv1D, GlobalMaxPooling1D,
                          Dropout, Dense, concatenate)
from keras.models import Model


# Written against the Keras 2 functional API; ConvModel is defined elsewhere in model.py.
def create(w2v, labels, **kwargs):
    model = ConvModel()
    model.labels = labels
    model.vocab = w2v.vocab
    filtsz = kwargs['filtsz']                      # convolution filter widths
    pdrop = kwargs.get('dropout', 0.5)             # dropout rate
    mxlen = int(kwargs.get('mxlen', 100))          # fixed input length in tokens
    cmotsz = kwargs['cmotsz']                      # feature maps per filter width
    finetune = bool(kwargs.get('finetune', True))  # whether the embeddings are trainable
    nc = len(labels)
    x = Input(shape=(mxlen,), dtype='int32', name='input')

    vocab_size = w2v.weights.shape[0]
    embedding_dim = w2v.dsz

    # Lookup table initialized from the pretrained word vectors
    lut = Embedding(input_dim=vocab_size, output_dim=embedding_dim,
                    weights=[w2v.weights], input_length=mxlen, trainable=finetune)

    embed = lut(x)

    # One convolution + max-over-time pooling branch per filter width
    mots = []
    for fsz in filtsz:
        conv = Conv1D(cmotsz, fsz, activation='relu')(embed)
        gmp = GlobalMaxPooling1D()(conv)
        mots.append(gmp)

    # Concatenated features: cmotsz * len(filtsz) values per example
    joined = concatenate(mots)
    drop1 = Dropout(pdrop)(joined)

    # Softmax over the nc labels
    dense = Dense(nc, activation='softmax')(drop1)
    model.impl = Model(inputs=[x], outputs=[dense])
    return model
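
To make the tensor shapes in create easier to follow, here is a minimal pure-Python sketch that traces them without needing Keras. The helper name trace_shapes is hypothetical and not part of the baseline project. It assumes Conv1D runs with its default settings (stride 1, 'valid' padding), so each branch shortens the sequence to mxlen - fsz + 1 steps before global max pooling collapses it to cmotsz values.

```python
def trace_shapes(mxlen, filtsz, cmotsz):
    """Trace per-example shapes (batch dimension omitted) through the
    convolution/pooling branches. Hypothetical helper for illustration."""
    branches = []
    for fsz in filtsz:
        # Assumes stride 1 and 'valid' padding: no zero-padding at the edges
        conv_len = mxlen - fsz + 1
        branches.append({
            'fsz': fsz,
            'conv': (conv_len, cmotsz),  # output of Conv1D
            'pooled': (cmotsz,),         # output of GlobalMaxPooling1D
        })
    # Concatenating the pooled branches gives the width fed to Dropout/Dense
    joined = cmotsz * len(filtsz)
    return branches, joined


if __name__ == '__main__':
    # Illustrative values; only the mxlen default (100) comes from create
    branches, joined = trace_shapes(mxlen=100, filtsz=[3, 4, 5], cmotsz=100)
    for b in branches:
        print(b)
    print('joined width:', joined)
```

With these illustrative values the three branches produce convolution outputs of length 98, 97 and 96, and the concatenated feature vector has 300 entries, independent of mxlen.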