sup_parser_v3_hierarchy_cnn.py source file

python

Project: conll16st-hd-sdp Author: tbmihailov
import logging
import pickle
import time

from sklearn.svm import SVC


def filter_items_train_classifier_and_save_model_logreg(classifier_name, class_mapping_curr, relation_type,
                                                        train_x, train_y_txt,
                                                        train_y_relation_types, save_model_file):
        """
        Filters items by the given params, trains the classifier and saves the trained model to a file.

        Note: despite the "logreg" suffix in the name, the classifier trained here is an SVC with an RBF kernel.

        Args:
            classifier_name: Name of the classifier, used in log messages
            class_mapping_curr: Dict mapping the text classes in train_y_txt to int labels.
            Items whose class is not in the mapping are filtered out
            relation_type: 1 Explicit, 0 Non-Explicit. Only items with this relation type are kept
            train_x: Train samples
            train_y_txt: Train sample classes - text classes that are filtered and mapped using class_mapping_curr
            train_y_relation_types: Per-sample indicators of whether the sample is explicit or implicit.
            Only items whose indicator equals relation_type are used for training
            save_model_file: Name of the file in which the trained classifier will be saved
        Returns:
            None. The trained classifier is pickled to save_model_file.
        """
        logging.info('======[%s] - filter_items_train_classifier_and_save_model_logreg======' % classifier_name)

        train_x_curr = []
        train_y_curr = []

        # Filtering items
        logging.info('Filtering %s items...' % len(train_x))
        start = time.time()
        for x, y_txt, y_relation_type in zip(train_x, train_y_txt, train_y_relation_types):
            if y_txt in class_mapping_curr and y_relation_type == relation_type:
                train_x_curr.append(x)
                train_y_curr.append(class_mapping_curr[y_txt])
        end = time.time()
        logging.info("Done in %s s" % (end - start))

        # Training
        # Classifier params
        classifier_current = SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
                                 degree=3, gamma='auto', kernel='rbf',
                                 max_iter=-1, probability=False, random_state=None, shrinking=True,
                                 tol=0.001, verbose=False)
        print('Classifier:\n%s' % classifier_current)

        start = time.time()
        logging.info('Training with %s items...' % len(train_x_curr))
        classifier_current.fit(train_x_curr, train_y_curr)
        end = time.time()
        logging.info("Done in %s s" % (end - start))

        # Saving the trained classifier; the with-block closes the file after writing
        with open(save_model_file, 'wb') as model_file:
            pickle.dump(classifier_current, model_file)
        logging.info('Model saved to %s' % save_model_file)
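
To make the filtering step concrete, here is a minimal standalone sketch of how `class_mapping_curr` and `relation_type` jointly select and relabel the training samples. The helper name `filter_items`, the toy feature vectors and the label names are assumptions made up for illustration; they do not come from the original project.

```python
def filter_items(train_x, train_y_txt, train_y_relation_types,
                 class_mapping_curr, relation_type):
    """Hypothetical helper that mirrors the filtering loop of the function above."""
    x_curr, y_curr = [], []
    for x, y_txt, rel in zip(train_x, train_y_txt, train_y_relation_types):
        # Keep a sample only if its class is mapped AND its relation type matches
        if y_txt in class_mapping_curr and rel == relation_type:
            x_curr.append(x)
            y_curr.append(class_mapping_curr[y_txt])
    return x_curr, y_curr


# Toy data (assumed for illustration only)
train_x = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
train_y_txt = ['Comparison', 'Contingency', 'Expansion', 'Comparison']
train_y_relation_types = [1, 1, 0, 0]  # 1 = explicit, 0 = non-explicit
class_mapping = {'Comparison': 0, 'Contingency': 1}

x_curr, y_curr = filter_items(train_x, train_y_txt, train_y_relation_types,
                              class_mapping, relation_type=1)
print(x_curr)
print(y_curr)
```

With this toy input only the first two samples survive: the third has a class missing from the mapping, and the fourth has a mapped class but the wrong relation type.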