SKLProcessors.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:Ossian 作者: CSTR-Edinburgh 项目源码 文件源码
def do_training(self, speech_corpus, text_corpus):
        """Train a decision-tree classifier on features dumped from *speech_corpus*.

        Does nothing if ``self.model`` is already set (truthy).

        Steps:
          1. For every utterance in *speech_corpus*, call
             ``utterance.dump_features(self.target_nodes, self.context_list,
             return_dict=True)``.  Each returned example dict must contain a
             ``'response'`` key; that entry is split off as the target and the
             rest of the dict is used as the input features.
          2. Vectorise inputs and targets with two separate ``DictVectorizer``
             instances, then fit
             ``tree.DecisionTreeClassifier(min_samples_leaf=self.min_samples_leaf)``.
          3. Pickle ``[x_vectoriser, y_vectoriser, model]`` to
             ``self.model_file`` and write a Graphviz rendering of the tree to
             ``self.model_file + '.dot'``.
          4. Call ``self.verify(self.voice_resources)``.  Per the original
             call-site note this presumably reloads the saved model and sets
             ``self.model`` -- TODO confirm against ``verify``.

        Args:
            speech_corpus: iterable of utterance objects exposing
                ``dump_features``.
            text_corpus: not used by this method.

        Note:
            The example dicts returned by ``dump_features`` are mutated: their
            ``'response'`` entry is removed.
        """
        if self.model:  # already trained -- nothing to do
            return

        ## 1) Get data: split each feature dict into inputs (x) and target (y).
        x_data = []
        y_data = []
        for utterance in speech_corpus:
            utt_feats = utterance.dump_features(self.target_nodes,
                                                self.context_list,
                                                return_dict=True)
            for example in utt_feats:
                assert 'response' in example, example
                # pop() removes the target so it is not also used as an input.
                y_data.append({'response': example.pop('response')})
                x_data.append(example)

        ## Handle categorical features (strings) while keeping numerical ones
        ## as they are:
        x_vectoriser = DictVectorizer()
        x_data = x_vectoriser.fit_transform(x_data).toarray()

        y_vectoriser = DictVectorizer()
        y_data = y_vectoriser.fit_transform(y_data).toarray()

        ## 2) Train classifier:
        model = tree.DecisionTreeClassifier(min_samples_leaf=self.min_samples_leaf)
        model.fit(x_data, y_data)

        # Computed once; reused for the log output and the graphviz export.
        x_feature_names = x_vectoriser.get_feature_names()

        # Single-argument print(...) produces the same output on Python 2 and 3.
        print('\n Trained classifier: ')
        print(model)
        print('\n Trained x vectoriser:')
        print(x_vectoriser)
        print('Feature names:')
        print(x_feature_names)
        print('\n Trained y vectoriser:')
        print(y_vectoriser)
        print('Feature names:')
        print(y_vectoriser.get_feature_names())

        ## 3) Save classifier by pickling.  'with' guarantees the file is
        ## closed even if pickling raises.
        with open(self.model_file, 'wb') as output:
            pickle.dump([x_vectoriser, y_vectoriser, model], output)

        ## Write ASCII tree representation (which can be plotted):
        tree.export_graphviz(model, out_file=self.model_file + '.dot',
                             feature_names=x_feature_names)

        self.verify(self.voice_resources)  # reload -- get self.model etc.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号