def simplified_data(num_train, num_dev, num_test):
    """Build a small, class-balanced train/dev/test split from all trees.

    Trees from the 'train', 'dev' and 'test' sets are pooled, and only those
    whose root label is 4 (kept as "pos") or 0 (kept as "neg") are retained.
    Both groups are passed to ``binarize_labels`` (presumably rewriting the
    labels in place to two classes -- confirm against its definition), then
    ordered by ``len(t.get_words())`` so the shortest trees are used first.

    Each requested size is halved with floor division, and that many trees
    are taken from each group, so an odd size yields one tree fewer than
    requested. If a group has too few trees, the affected splits are simply
    shorter; no error is raised.

    The global ``random`` state is saved, seeded with 0 for a reproducible
    shuffle, and restored before returning -- including when an exception is
    raised part-way through.

    Args:
        num_train: total number of training trees wanted (pos + neg).
        num_dev: total number of dev trees wanted (pos + neg).
        num_test: total number of test trees wanted (pos + neg).

    Returns:
        A ``(train, dev, test)`` tuple of shuffled lists of trees.
    """
    rndstate = random.getstate()
    random.seed(0)
    try:
        trees = loadTrees('train') + loadTrees('dev') + loadTrees('test')

        # Keep only the trees with the two extreme root labels.
        pos_trees = [t for t in trees if t.root.label == 4]
        neg_trees = [t for t in trees if t.root.label == 0]

        binarize_labels(pos_trees)
        binarize_labels(neg_trees)

        # %-formatting prints the same "P N" line on Python 2 and Python 3.
        print("%d %d" % (len(pos_trees), len(neg_trees)))

        # Shortest trees first (sorted() is stable, so ties keep load order).
        pos_trees = sorted(pos_trees, key=lambda t: len(t.get_words()))
        neg_trees = sorted(neg_trees, key=lambda t: len(t.get_words()))

        # Per-class sizes. Floor division keeps the slice bounds integral
        # even under true division (Python 3 / __future__ division).
        num_train //= 2
        num_dev //= 2
        num_test //= 2

        dev_start = num_train
        test_start = dev_start + num_dev
        test_end = test_start + num_test

        train = pos_trees[:dev_start] + neg_trees[:dev_start]
        dev = (pos_trees[dev_start:test_start]
               + neg_trees[dev_start:test_start])
        test = (pos_trees[test_start:test_end]
                + neg_trees[test_start:test_end])

        random.shuffle(train)
        random.shuffle(dev)
        random.shuffle(test)
    finally:
        # Always hand the caller back the RNG state they had on entry.
        random.setstate(rndstate)
    return train, dev, test
评论列表
文章目录