def _append_unmonitored(data, labels, unmonitored):
    """Append *unmonitored* samples to *data* and give each one the label -1.

    Both outputs are derived from the same *unmonitored* array, so the number
    of appended samples always equals the number of appended ``-1`` labels.
    Note that ``np.append`` without an axis flattens its inputs.
    """
    return (np.append(data, unmonitored),
            np.append(labels, [-1] * len(unmonitored)))


def main(_):
    """Build train/test splits of monitored and unmonitored data and train.

    Reads samples from ``data/cells`` two directories above this file, splits
    the monitored samples with a stratified shuffle split (seed 123) and the
    unmonitored samples by position, labels every unmonitored sample ``-1``,
    persists the test split with ``store_data`` and trains with ``run_model``.
    A KeyboardInterrupt prints a message and exits with status 0.

    Args:
        _: Unused; presumably the argv handed over by an app runner — confirm.
    """
    dirname, _ = ospath.split(ospath.abspath(__file__))
    try:
        data_dir = ospath.join(dirname, '..', '..', 'data', 'cells')
        # NOTE(review): with in_memory=False, import_data presumably yields
        # file paths rather than loaded samples — confirm.
        paths, labels = import_data(data_dir=data_dir, in_memory=False,
                                    extension=args.extension)
        monitored_data, monitored_label, unmonitored_data = split_mon_unmon(paths, labels)
        monitored_data = np.array(monitored_data)
        monitored_label = np.array(monitored_label)
        unmonitored_data = np.array(unmonitored_data)

        # The return value is discarded, so this presumably shuffles the
        # array in place — confirm against helpers.shuffle_data.
        helpers.shuffle_data(unmonitored_data)
        # Unmonitored samples have no class to stratify on: split by position.
        split_at = int((1 - TEST_SIZE) * len(unmonitored_data))
        unmon_train = unmonitored_data[:split_at]
        unmon_test = unmonitored_data[split_at:]

        # n_splits=1, so the generator yields exactly one (train, test) pair.
        sss = StratifiedShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=123)
        train_index, test_index = next(sss.split(monitored_data, monitored_label))

        # NOTE(review): y_train is built but run_model only receives X_train;
        # presumably labels are recovered from the paths — confirm.
        X_train, y_train = _append_unmonitored(
            monitored_data[train_index], monitored_label[train_index], unmon_train)
        # Labels for the unmonitored test samples are sized from unmon_test
        # (the original used len(unmon_train), misaligning y_test with X_test).
        X_test, y_test = _append_unmonitored(
            monitored_data[test_index], monitored_label[test_index], unmon_test)

        store_data(X_test, 'X_test')
        store_data(y_test, 'y_test')
        stdout.write("Training on data...\n")
        run_model(X_train, in_memory=False)
        stdout.write("Finished running model.\n")
    except KeyboardInterrupt:
        stdout.write("Interrupted, this might take a while...\n")
        exit(0)
train_autoencoder.py 文件源码
python
阅读 31
收藏 0
点赞 0
评论 0
评论列表
文章目录