def onehot_gene(DB, tr, te):
    """Return one-hot encodings of the ``Gene`` column for a train/test pair.

    The function has two modes, selected by ``tr``:

    * ``tr is None``: ``te`` names the test stage. For ``"stage1"`` the
      train set is ``DB.data['training_variants']`` and the test set is
      ``DB.data['test_variants_filter']``. For any other value, the
      ``test_variants_filter`` rows are appended to the training rows and
      ``DB.data['stage2_test_variants']`` becomes the test set.
    * otherwise: ``DB.data['training_variants']`` is encoded as a whole
      and ``tr`` / ``te`` are used to index the encoded result.

    Args:
        DB: Object whose ``data`` attribute maps table names to tables
            that have a ``'Gene'`` column.
        tr: ``None`` to select the stage mode, otherwise an indexer
            applied to the encoded training table.
        te: ``"stage1"`` (or any other value for the second stage) when
            ``tr`` is ``None``; otherwise an indexer applied to the
            encoded training table.

    Returns:
        A ``(train_encoding, test_encoding)`` tuple.

    Note:
        ``lbl_encode`` is called on the tables before ``Gene`` is read.
        NOTE(review): it presumably label-encodes in place, which would
        also modify tables taken directly from ``DB.data`` -- confirm
        against its definition.
    """
    from utils.np_utils.encoder import onehot_encode

    if tr is None:
        train = DB.data['training_variants']
        if te == "stage1":
            test = DB.data['test_variants_filter']
        else:
            train = pd.concat(
                [train, DB.data['test_variants_filter']], axis=0
            )
            test = DB.data['stage2_test_variants']

        lbl_encode(train, test)

        # Use one shared width for both encodings so the train and test
        # matrices have the same number of columns. Previously only the
        # train encoding received ``n``.
        # NOTE(review): assumes ``n`` is the number of one-hot columns and
        # labels start at 0 -- confirm against onehot_encode.
        n = max(train['Gene'].max(), test['Gene'].max())
        width = n + 1
        gtr = onehot_encode(train['Gene'].values, n=width)
        gte = onehot_encode(test['Gene'].values, n=width)
        return gtr, gte
    else:
        data = DB.data['training_variants']
        lbl_encode(data, cols=['Gene'])
        gene = onehot_encode(data['Gene'].values)
        return gene[tr], gene[te]
评论列表
文章目录