def getTestData(test_path, seq_max_len, is_validation=True):
    """Load a tab-separated word/label file and turn it into model inputs.

    The file at ``test_path`` is read with two columns, ``word`` and
    ``label``, without a header row. Blank lines are kept
    (``skip_blank_lines=False``); their NaN cells are mapped to ``-1``.
    Presumably ``prepare`` uses ``-1`` as a sequence separator -- confirm
    against ``prepare``.

    Words that are not present in the ``data/word2id`` map are replaced by
    the id of the ``'<NEW>'`` entry. Labels are looked up directly in the
    ``data/label2id`` map, so an unknown label raises ``KeyError``.

    Args:
        test_path: Path of the tab-separated file to read.
        seq_max_len: Forwarded unchanged to ``prepare``.
        is_validation: Selects which outputs are built (see Returns).

    Returns:
        When ``is_validation`` is true, the pair produced by
        ``prepare(word ids, label ids, seq_max_len)``.
        Otherwise a triple: the first output of ``prepare`` for the word
        ids, followed by the two outputs of ``prepare`` applied to the raw
        word and label columns with ``is_padding=False``.
    """
    nan_text = str(np.nan)

    def is_blank(cell):
        # Missing cells arrive as float NaN; compare via its string form.
        return str(cell) == nan_text

    def word_to_index(token):
        if is_blank(token):
            return -1
        if token in word_ids:
            return word_ids[token]
        # Out-of-vocabulary words share the '<NEW>' id.
        return word_ids['<NEW>']

    def label_to_index(tag):
        return -1 if is_blank(tag) else label_ids[tag]

    def blank_to_marker(cell):
        return -1 if is_blank(cell) else cell

    # Only the forward (string -> id) maps are needed here.
    word_ids, _ = loadMap('data/word2id')
    label_ids, _ = loadMap('data/label2id')

    frame = pd.read_csv(
        test_path,
        delimiter='\t',
        skip_blank_lines=False,
        header=None,
        quoting=csv.QUOTE_NONE,
        names=['word', 'label'],
    )
    frame['word_id'] = frame.word.map(word_to_index)
    frame['label_id'] = frame.label.map(label_to_index)

    if is_validation:
        inputs, targets = prepare(frame['word_id'], frame['label_id'], seq_max_len)
        return inputs, targets

    # Inference mode: also return the raw strings, with blanks marked as -1.
    frame['word'] = frame.word.map(blank_to_marker)
    frame['label'] = frame.label.map(blank_to_marker)
    # The second output is discarded, so the word ids double as a placeholder.
    inputs, _ = prepare(frame['word_id'], frame['word_id'], seq_max_len)
    raw_words, raw_labels = prepare(
        frame['word'], frame['label'], seq_max_len, is_padding=False
    )
    return inputs, raw_words, raw_labels
评论列表
文章目录