def load_sts(dsfile, skip_unlabeled=True):
    """Load a dataset in the STS tab-separated format.

    Each non-empty line of ``dsfile`` must hold exactly three tab-separated
    fields: ``label<TAB>sentence0<TAB>sentence1``.  The file is decoded as
    UTF-8.  Completely empty lines (e.g. a trailing blank line) are ignored.

    Args:
        dsfile: Path of the TSV file to read.
        skip_unlabeled: When true (the default), rows whose label field is
            empty are dropped.  When false they are kept and given the
            placeholder label ``-1.0``.

    Returns:
        A tuple ``(s0, s1, labels)``.  ``s0`` and ``s1`` are lists holding
        the result of ``word_tokenize`` for the first and second sentence of
        every kept row; ``labels`` is a NumPy array with the float label of
        each kept row, in the same order.

    Raises:
        ValueError: If a non-empty row does not have exactly three
            tab-separated fields, or if a non-empty label cannot be parsed
            by ``float``.
    """
    s0 = []
    s1 = []
    labels = []
    with codecs.open(dsfile, encoding='utf8') as f:
        for lineno, line in enumerate(f, 1):
            # Strip only the line terminator.  A bare rstrip() would also
            # eat trailing tabs, so a row with an empty last sentence would
            # lose its delimiter and fail to split into three fields.
            row = line.rstrip('\r\n')
            if not row:
                continue
            fields = row.split('\t')
            if len(fields) != 3:
                raise ValueError(
                    '%s:%d: expected 3 tab-separated fields, got %d'
                    % (dsfile, lineno, len(fields)))
            label, s0x, s1x = fields
            if label == '':
                if skip_unlabeled:
                    continue
                # Placeholder score marking a row without a gold label.
                score = -1.
            else:
                score = float(label)
            labels.append(score)
            s0.append(word_tokenize(s0x))
            # Trailing whitespace of the line used to be removed before the
            # split; keep doing that for the last field so the tokenizer
            # sees the same text as before for well-formed rows.
            s1.append(word_tokenize(s1x.rstrip()))
    return (s0, s1, np.array(labels))
# [page navigation residue] Comment list
# [page navigation residue] Table of contents