classifier.py source code

python
Project: narrative-prediction  Author: roemmele
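from keras.models import Sequential
from keras.layers import Embedding, GRU, Dense

# Method excerpt from the classifier class; `self` supplies the hyperparameters.
def create_model(self):
    model = Sequential()
    # input_dim is lexicon_size + 1 because index 0 is reserved for padding (mask_zero=True).
    model.add(Embedding(output_dim=self.n_embedding_nodes, input_dim=self.lexicon_size + 1,
                        input_length=self.n_timesteps, mask_zero=True, name='embedding_layer'))
    for layer_num in range(self.n_hidden_layers):
        # Every GRU except the last returns the full sequence so the next GRU can consume it;
        # the last one returns only its final state for the output layer.
        return_sequences = layer_num < self.n_hidden_layers - 1
        # output_dim is the Keras 1.x argument name for GRU and Dense (renamed to units in Keras 2).
        model.add(GRU(output_dim=self.n_hidden_nodes, return_sequences=return_sequences,
                      name='hidden_layer' + str(layer_num + 1)))
    model.add(Dense(output_dim=self.n_output_classes, activation='softmax', name='output_layer'))
    # if emb_weights is not None:
    #     # initialize weights with lm weights
    #     model.layers[0].set_weights(emb_weights)  # set embeddings
    # if layer1_weights is not None:
    #     model.layers[1].set_weights(layer1_weights)  # set recurrent layer 1
    # Sparse loss: targets are integer class indices, not one-hot vectors.
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model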
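
The loop in create_model decides, per recurrent layer, whether the layer passes on its full output sequence or only its final state. The following is a minimal sketch of that decision alone, runnable without Keras. The helper name `gru_layer_specs` and the dictionary keys are hypothetical and not part of classifier.py; only the layer naming scheme and the return_sequences rule are taken from the code above.

```python
def gru_layer_specs(n_hidden_layers, n_hidden_nodes):
    """Hypothetical helper (not in classifier.py): describe each GRU layer in the stack."""
    specs = []
    for layer_num in range(n_hidden_layers):
        specs.append({
            # Same naming scheme as create_model: hidden_layer1, hidden_layer2, ...
            'name': 'hidden_layer' + str(layer_num + 1),
            'n_hidden_nodes': n_hidden_nodes,
            # Only the last layer drops the time dimension.
            'return_sequences': layer_num < n_hidden_layers - 1,
        })
    return specs


if __name__ == '__main__':
    # Illustrative sizes only; the original does not state its hyperparameter values.
    for spec in gru_layer_specs(n_hidden_layers=3, n_hidden_nodes=200):
        print(spec['name'], spec['return_sequences'])
```

With a single hidden layer the only GRU gets return_sequences=False, which matches the behaviour of the original if/else branch.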