classifier.py source code

python

Project: narrative-prediction  Author: roemmele
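from keras.models import Sequential
from keras.layers import Embedding, Reshape, Dense
from keras import optimizers

def create_model(self, n_timesteps, batch_size=1, pred_layer=True):
    """Build a feed-forward model over a fixed window of n_timesteps token ids."""
    model = Sequential()

    # Map each token id (0..lexicon_size) to a dense vector.
    # if self.embeddings is None:
    model.add(Embedding(self.lexicon_size + 1, self.n_embedding_nodes,
                        batch_input_shape=(batch_size, n_timesteps)))  # , mask_zero=True))

    # Concatenate the n_timesteps embeddings into one flat vector.
    model.add(Reshape((self.n_embedding_nodes * n_timesteps,)))

    # Stack of tanh hidden layers. Only the first layer of a Sequential model
    # takes an input shape, so no batch_input_shape is passed here.
    for _ in range(self.n_hidden_layers):
        model.add(Dense(self.n_hidden_nodes, activation='tanh'))

    # Optional softmax output layer with one unit per lexicon index.
    if pred_layer:
        model.add(Dense(self.lexicon_size + 1, activation="softmax"))

    # Look up the optimizer class by name (e.g. "RMSprop") instead of eval(), then compile.
    # lr and decay are the legacy Keras 2.x argument names.
    optimizer_cls = getattr(optimizers, self.optimizer)
    model.compile(loss="sparse_categorical_crossentropy",
                  optimizer=optimizer_cls(clipvalue=self.clipvalue, lr=self.lr, decay=self.decay))

    return model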
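To make the layer stack concrete, here is a dependency-free sketch of the forward pass that create_model configures: embedding lookup, concatenation (the Reshape), tanh hidden layers, a softmax output, and the sparse categorical cross-entropy used as the loss. It uses plain Python lists with small random weights rather than Keras, and the helper names (embed, flatten, dense, init_params, forward) are hypothetical. It also assumes the + 1 in lexicon_size + 1 reserves index 0, which the commented-out mask_zero=True suggests is meant for padding.

```python
import math
import random


def embed(token_ids, table):
    """Look up one embedding row per token id (the Embedding layer)."""
    return [table[t] for t in token_ids]


def flatten(rows):
    """Concatenate per-timestep vectors into one vector (the Reshape layer)."""
    return [x for row in rows for x in row]


def dense(x, weights, bias, activation):
    """Fully connected layer: activation(W @ x + b), with W stored as (out, in)."""
    return [activation(sum(w * xi for w, xi in zip(row, x)) + b)
            for row, b in zip(weights, bias)]


def softmax(z):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_categorical_crossentropy(probs, target_id):
    """Loss for an integer target: negative log-probability of the true class."""
    return -math.log(probs[target_id])


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


def init_params(lexicon_size, n_timesteps, n_embedding_nodes,
                n_hidden_nodes, n_hidden_layers, seed=0):
    """Hypothetical parameter layout mirroring create_model's layer sizes."""
    rng = random.Random(seed)
    vocab = lexicon_size + 1  # assumed: index 0 reserved (e.g. for padding)
    params = {"embedding": random_matrix(vocab, n_embedding_nodes, rng),
              "hidden": []}
    in_dim = n_embedding_nodes * n_timesteps  # size after the Reshape
    for _ in range(n_hidden_layers):
        params["hidden"].append((random_matrix(n_hidden_nodes, in_dim, rng),
                                 [0.0] * n_hidden_nodes))
        in_dim = n_hidden_nodes
    params["output"] = (random_matrix(vocab, in_dim, rng), [0.0] * vocab)
    return params


def forward(token_ids, params):
    """Embedding -> flatten -> tanh hidden layers -> softmax over the lexicon."""
    h = flatten(embed(token_ids, params["embedding"]))
    for weights, bias in params["hidden"]:
        h = dense(h, weights, bias, math.tanh)
    out_w, out_b = params["output"]
    return softmax(dense(h, out_w, out_b, lambda v: v))


params = init_params(lexicon_size=10, n_timesteps=3, n_embedding_nodes=4,
                     n_hidden_nodes=5, n_hidden_layers=2)
probs = forward([1, 2, 3], params)
loss = sparse_categorical_crossentropy(probs, target_id=4)
print(len(probs), round(sum(probs), 6), round(loss, 3))
```

With lexicon_size=10, n_timesteps=3 and n_embedding_nodes=4, the flattened input has 12 values and forward returns 11 probabilities. Because the window length is fixed, the Reshape can flatten to a known size, which is why the model is built for a specific n_timesteps.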