def _splitTrainTest(samples, trainPct):
    """Randomly split the rows of ``samples`` into a train and a test part.

    The rows are shuffled with ``np.random.permutation`` (global NumPy RNG,
    so ``np.random.seed`` controls reproducibility). The first
    ``int(n_rows * trainPct)`` shuffled rows form the train part and the
    remaining rows form the test part.

    Returns:
        A ``(train, test)`` tuple of arrays.
    """
    n_samples = samples.shape[0]
    order = np.random.permutation(n_samples)
    cutPt = int(n_samples * trainPct)
    return samples[order[:cutPt]], samples[order[cutPt:]]


def getCytoRNADataFromCsv(dataPath, batchesPath, batch1, batch2, trainPct = 0.8):
    """Load two batches from CSV files and split each into train/test sets.

    Both files are read with ``genfromtxt`` as comma-delimited text with no
    header row skipped. The batches file is used as a per-row label for the
    data file: rows whose label equals ``batch1`` become the "source" set and
    rows whose label equals ``batch2`` become the "target" set. The label
    array must therefore line up with the rows of the data array.

    Args:
        dataPath: Path to the comma-delimited data file.
        batchesPath: Path to the comma-delimited file of batch labels.
        batch1: Label value selecting the source rows.
        batch2: Label value selecting the target rows.
        trainPct: Fraction of each batch assigned to the train part; must
            lie in the closed interval [0, 1].

    Returns:
        A tuple ``(source_train, source_test, target_train, target_test)``.

    Raises:
        ValueError: If ``trainPct`` is not within [0, 1].
    """
    # Reject fractions outside [0, 1] up front: a negative value would
    # silently produce a nonsensical split through negative slicing, and a
    # value above 1 would silently leave the test part empty.
    if not (0 <= trainPct <= 1):
        raise ValueError("trainPct must be between 0 and 1, got %r" % (trainPct,))

    data = genfromtxt(dataPath, delimiter=',', skip_header=0)
    batches = genfromtxt(batchesPath, delimiter=',', skip_header=0)

    # Boolean-mask selection of the rows belonging to each batch.
    source = data[batches == batch1]
    target = data[batches == batch2]

    # Source is split before target so that the sequence of draws from the
    # global RNG is the same for a given seed.
    source_train, source_test = _splitTrainTest(source, trainPct)
    target_train, target_test = _splitTrainTest(target, trainPct)
    return source_train, source_test, target_train, target_test
评论列表
文章目录