models.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:cervantes 作者: textclf 项目源码 文件源码
def _generate_model(self, lembedding, num_classes=2, unit='gru', rnn_size=128, train_vectors=True):
    """Build and compile a bidirectional RNN text classifier (Keras 1 functional API).

    Architecture: integer token ids -> Embedding -> one forward and one
    backward recurrent layer (GRU or LSTM) -> concatenation -> Dropout(0.5)
    -> Dense output layer.

    Args:
        lembedding: embedding descriptor. Reads ``lembedding.size`` (length of
            each integer input vector, used as the ``Input`` shape) and
            ``lembedding.vector_box.size`` / ``.vector_dim`` / ``.W``
            (Embedding ``input_dim``, output dimension, and an optional
            initial weight matrix; ``None`` means random initialisation).
        num_classes: ``2`` builds a single sigmoid unit trained with
            binary cross-entropy; any other value builds a
            ``num_classes``-way softmax trained with categorical cross-entropy.
        unit: ``'gru'`` selects GRU layers; any other value selects LSTM.
        rnn_size: output dimension of each recurrent layer.
        train_vectors: whether the embedding weights are updated in training.

    Returns:
        The compiled ``Model``.

    Side effects:
        If ``self.optimizer`` is ``None`` it is set to ``'rmsprop'`` (binary)
        or ``'adam'`` (multi-class), and that value persists on the instance.
    """
    # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs = Input(shape=(lembedding.size,), dtype='int32')

    vector_box = lembedding.vector_box
    embedding_kwargs = {}
    if vector_box.W is not None:
        # Initialise the embedding from the supplied weight matrix.
        embedding_kwargs['weights'] = [vector_box.W]
    # Bug fix: ``trainable`` must be set on the Embedding *layer*. The previous
    # code assigned ``trainable`` on the layer's output tensor, which Keras
    # ignores, so ``train_vectors=False`` never froze the embeddings.
    emb = Embedding(vector_box.size,
                    vector_box.vector_dim,
                    W_constraint=None,
                    trainable=train_vectors,
                    **embedding_kwargs)(inputs)

    rnn_cls = GRU if unit == 'gru' else LSTM
    forward = rnn_cls(rnn_size)(emb)
    backward = rnn_cls(rnn_size, go_backwards=True)(emb)

    merged_rnn = merge([forward, backward], mode='concat')
    dropped = Dropout(0.5)(merged_rnn)

    if num_classes == 2:
        out = Dense(1, activation='sigmoid')(dropped)
        loss, default_optimizer = 'binary_crossentropy', 'rmsprop'
    else:
        out = Dense(num_classes, activation='softmax')(dropped)
        loss, default_optimizer = 'categorical_crossentropy', 'adam'

    model = Model(input=inputs, output=out)
    if self.optimizer is None:
        self.optimizer = default_optimizer
    model.compile(loss=loss, optimizer=self.optimizer, metrics=["accuracy"])
    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号