load_data.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:SwissCheese-at-SemEval-2016 作者: naivechen 项目源码 文件源码
def _read_examples(filename, max_examples):
    """Return the first ``max_examples`` lines of ``filename``, whitespace-stripped.

    ``max_examples=None`` keeps every line of the file.
    """
    # ``with`` guarantees the handle is closed once the lines are read.
    with open(filename, "r") as handle:
        lines = handle.readlines()
    return [line.strip() for line in lines[:max_examples]]


def get_train_data(language, max_examples=100, dev_fraction=0.1, seed=1234):
    """Load the rt-polarity data for ``language`` and split it into train/dev sets.

    Reads ``./data/<language>/rt-polarity.pos`` and ``rt-polarity.neg``, fits a
    ``VocabularyProcessor`` over all sentences, shuffles the examples with a
    fixed seed and holds out a fraction of them for cross-validation.

    Args:
        language: name of the sub-directory of ``./data/`` holding the files.
        max_examples: number of lines read from each file (``None`` reads all).
            The default of 100 matches the previous hard-coded limit.
        dev_fraction: fraction of the shuffled examples placed in the dev set;
            the dev set holds ``int(total * dev_fraction)`` examples.
        seed: value passed to ``np.random.seed`` before shuffling.

    Returns:
        ``[x_train, x_dev, y_train, y_dev, vocab_processor]``, where the ``x``
        arrays hold the rows produced by ``vocab_processor.fit_transform`` and
        the ``y`` arrays hold one-hot labels (``[0, 1]`` for positive examples,
        ``[1, 0]`` for negative ones).

    Raises:
        ValueError: if neither file yields any example.
    """
    # Load data from files
    path = "./data/" + language + "/"
    positive_examples = _read_examples(path + "rt-polarity.pos", max_examples)
    negative_examples = _read_examples(path + "rt-polarity.neg", max_examples)
    x_text = positive_examples + negative_examples
    if not x_text:
        raise ValueError("no examples found under " + path)

    # Generate labels: column 1 marks positive, column 0 marks negative.
    # Building one array from the combined list also works when one file is
    # empty, where np.concatenate would reject the (0,)-shaped empty part.
    labels = [[0, 1] for _ in positive_examples] + [[1, 0] for _ in negative_examples]
    y = np.array(labels)

    # Build vocabulary
    # NOTE(review): the length is measured with jieba tokens, but the
    # VocabularyProcessor is built without a tokenizer_fn, so it presumably
    # tokenizes with its own default and the counts may disagree (rows may be
    # over-padded or truncated). TODO confirm the intended tokenization.
    max_length_of_sentence = max(len(jieba.lcut(sentence)) for sentence in x_text)
    vocab_processor = learn.preprocessing.VocabularyProcessor(max_length_of_sentence)
    x = np.array(list(vocab_processor.fit_transform(x_text)))

    # Randomly shuffle data; seeding the global RNG keeps the shuffle reproducible.
    np.random.seed(seed)
    shuffle_indices = np.random.permutation(np.arange(len(y)))
    x_shuffled = x[shuffle_indices]
    y_shuffled = y[shuffle_indices]

    # Split train/cross-validation set. The rows are already in random order,
    # so the leading ``dev_size`` rows are a random, seed-reproducible dev set.
    # Slicing also avoids random.sample on an ndarray (TypeError on Python 3)
    # and float-typed index arrays when the dev set is empty.
    dev_size = int(len(y) * dev_fraction)
    x_train, x_dev = x_shuffled[dev_size:], x_shuffled[:dev_size]
    y_train, y_dev = y_shuffled[dev_size:], y_shuffled[:dev_size]

    return [x_train, x_dev, y_train, y_dev, vocab_processor]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号