gensim_classifier.py source code

python

Project: sport-news-retrieval  Author: Andyccs
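import logging

import joblib
from gensim.models import Word2Vec
from sklearn.impute import SimpleImputer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import label_binarize
from sklearn.svm import LinearSVC

# get_labels, get_labelled_tweets, train_test_split, getAvgFeatureVecs,
# create_directory and evaluate are project helpers defined elsewhere in
# sport-news-retrieval; their imports are not shown on this page.


def gensim_classifier():
  logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
  label_list = get_labels()
  tweet_list = get_labelled_tweets()

  # split every tweet into a list of words
  sentences = [tweet.split() for tweet in tweet_list]

  # parameters for the Word2Vec model
  num_features = 100    # dimensionality of the word vectors
  min_word_count = 1    # keep every word, however rare
  num_workers = 4       # number of training threads
  context = 2           # context window size
  downsampling = 1e-3   # downsampling threshold for frequent words

  # initialize and train the model
  # (gensim >= 4.0 names the dimensionality parameter vector_size; older releases used size)
  w2v_model = Word2Vec(sentences, workers=num_workers,
                       vector_size=num_features, min_count=min_word_count,
                       window=context, sample=downsampling, seed=1)

  # project-specific split: returns the split index and both partitions
  index_value, train_set, test_set = train_test_split(0.80, sentences)
  train_vector = getAvgFeatureVecs(train_set, w2v_model, num_features)
  test_vector = getAvgFeatureVecs(test_set, w2v_model, num_features)

  # replace NaN values with column means learned from the training set only
  # (SimpleImputer supersedes the removed sklearn.preprocessing.Imputer)
  imputer = SimpleImputer()
  train_vector = imputer.fit_transform(train_vector)
  test_vector = imputer.transform(test_vector)

  # train model and predict
  classifier_fitted = OneVsRestClassifier(LinearSVC()).fit(train_vector, label_list[:index_value])
  result = classifier_fitted.predict(test_vector)

  # output result to csv
  create_directory('data')
  result.tofile("data/w2v_linsvc.csv", sep=',')

  # store the fitted classifier to mmap-able files
  create_directory('model')
  joblib.dump(classifier_fitted, 'model/w2v_linsvc.pkl')

  # evaluation
  # class_list is not defined in this snippet; deriving it from the labels is an assumption
  class_list = sorted(set(label_list))
  label_score = classifier_fitted.decision_function(test_vector)
  binarise_result = label_binarize(result, classes=class_list)
  binarise_labels = label_binarize(label_list, classes=class_list)

  evaluate(binarise_result, binarise_labels[index_value:], label_score, 'w2v_linsvc')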
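
The function above calls two helpers that this page does not show: train_test_split(0.80, sentences), which returns a split index together with the two partitions, and getAvgFeatureVecs(...), which turns each tokenised tweet into one fixed-length vector. The sketch below is one plausible reading of what they do, not the project's actual code. The names split_by_ratio and average_word_vectors are hypothetical, and so is the choice to emit a row of NaN for a tweet with no known words. That choice is only a guess based on the imputation step that follows the averaging in the original. A plain dict stands in for the trained word vectors so that the example runs without gensim.

```python
import math


def split_by_ratio(ratio, items):
    """Split items in order; return (split_index, head, tail).

    Hypothetical stand-in for the project's train_test_split helper.
    """
    index = int(len(items) * ratio)
    return index, items[:index], items[index:]


def average_word_vectors(sentences, vectors, num_features):
    """Average the vectors of the known words in each tokenised sentence.

    `vectors` is any mapping that supports `word in vectors` and
    `vectors[word]`. A sentence with no known words yields a row of NaN
    (an assumption), which a later imputation step can fill in.
    """
    rows = []
    for words in sentences:
        totals = [0.0] * num_features
        known = 0
        for word in words:
            if word in vectors:
                for i, value in enumerate(vectors[word]):
                    totals[i] += value
                known += 1
        if known:
            rows.append([total / known for total in totals])
        else:
            rows.append([math.nan] * num_features)
    return rows


# Toy data: two-dimensional "word vectors" and five tokenised tweets.
toy_vectors = {"goal": [1.0, 0.0], "match": [0.0, 1.0]}
toy_sentences = [["goal", "match"], ["goal", "unknown"], ["unknown"], ["match"], ["goal"]]

index, train, test = split_by_ratio(0.80, toy_sentences)
features = average_word_vectors(toy_sentences, toy_vectors, 2)
```

With the toy data, the first tweet averages to [0.5, 0.5], the second ignores the unknown word and yields [1.0, 0.0], and the third has no known words and becomes a NaN row. Because the split keeps the original order, slicing the labels with the same index (label_list[:index_value] and binarise_labels[index_value:] in the original) keeps labels aligned with their tweets.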