learning_model.py source code

python

Project: Auto-correction-for-transliterated-queries    Author: GauravBh1010tt
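import pickle

from keras.layers import Dense, Embedding, Input, LSTM
from keras.models import Model


def train(self, data):
    # Method of the project's model class (Keras 2 API). prepare_data() and the
    # LossHistory callback (which collects per-batch losses in .losses) are
    # defined elsewhere in the project.
    print('building model.....')
    # Expected to fill self.X_data (int windows of length self.step) and
    # self.y_data (one-hot targets of length self.vocab_size), which is what
    # the Input layer and categorical_crossentropy below require.
    self.prepare_data(data, re_train=True)

    # Embedding -> LSTM(128) -> softmax: maps a window of self.step token ids
    # to a probability distribution over the vocabulary.
    inputs = Input(shape=(self.step,), dtype='int32')
    embed = Embedding(self.vocab_size, self.embedding_dims, input_length=self.step)(inputs)
    encode = LSTM(128)(embed)
    pred = Dense(self.vocab_size, activation='softmax')(encode)
    model = Model(inputs=inputs, outputs=pred)  # Keras 2 keyword names
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    history = LossHistory()
    model.fit(self.X_data, self.y_data,
              batch_size=self.batch_size,
              epochs=self.nb_epoch,  # `nb_epoch` was renamed `epochs` in Keras 2
              callbacks=[history])
    self.history = history

    # Persist the recorded losses, the architecture (JSON) and the weights (HDF5).
    with open('history', 'wb') as h:
        pickle.dump(history.losses, h)

    model_json = model.to_json()
    with open("model.json", "w") as json_file:
        json_file.write(model_json)
    model.save_weights("model.h5")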
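prepare_data() is not part of this file, but the model fixes the shapes it must produce: X_data holds integer windows of self.step ids, and y_data holds one-hot vectors of length self.vocab_size for categorical_crossentropy. The sketch below assumes a character-level setup where each target is the character following its window; the helper name make_windows and that pairing are illustrative assumptions, not the project's actual implementation.

```python
from typing import Dict, List, Tuple


def make_windows(text: str, step: int) -> Tuple[List[List[int]], List[List[int]], Dict[str, int]]:
    """Hypothetical helper (not from the project): build (X, y, char_to_id).

    Assumes the target for each window is the character that follows it.
    X[i] is a list of `step` character ids  -> model input, shape (n, step).
    y[i] is a one-hot list over the vocab   -> target, shape (n, vocab_size).
    """
    vocab = sorted(set(text))
    char_to_id = {ch: i for i, ch in enumerate(vocab)}
    ids = [char_to_id[ch] for ch in text]

    X, y = [], []
    # Slide a window of length `step` over the id sequence.
    for start in range(len(ids) - step):
        X.append(ids[start:start + step])
        target = [0] * len(vocab)
        target[ids[start + step]] = 1  # mark the following character
        y.append(target)
    return X, y, char_to_id


if __name__ == "__main__":
    X, y, char_to_id = make_windows("hello world", step=3)
    print(len(X), len(char_to_id))  # 8 windows, 8 distinct characters
    print(X[0], y[0])               # [3, 2, 4] [0, 0, 0, 0, 1, 0, 0, 0]
```

In this sketch len(char_to_id) would play the role of self.vocab_size, and the lists would typically be converted with numpy.array before being passed to model.fit.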