train_seq2seq.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:website-fingerprinting 作者: AxelGoetz 项目源码 文件源码
def main(_):
    """Train the seq2seq model on the cell-trace dataset and persist the test split.

    Loads trace paths and labels from ``data/cells`` (two directories above
    this file) via ``import_data`` with ``in_memory=False``. It then separates
    monitored from unmonitored traces and builds train/test sets:

    * Monitored traces are split with a single ``StratifiedShuffleSplit``
      (``TEST_SIZE`` held out, ``random_state=123``), so each class keeps its
      proportion in both halves.
    * Unmonitored traces are shuffled, and the last ``TEST_SIZE`` fraction is
      held out for testing.

    Unmonitored traces are appended after the monitored ones in each set. In
    the test labels they receive the label ``-1``. The test set is written via
    ``store_data`` under the names ``'X_test'`` and ``'y_test'``. The training
    inputs are passed to ``run_model``. On ``KeyboardInterrupt`` a message is
    printed and the process exits with status 0.

    Args:
        _: Ignored positional argument (presumably supplied by whatever runner
            invokes ``main`` -- confirm at the call site).
    """
    dirname = ospath.dirname(ospath.abspath(__file__))

    try:
        data_dir = ospath.join(dirname, '..', '..', 'data', 'cells')
        paths, labels = import_data(data_dir=data_dir, in_memory=False, extension=args.extension)

        monitored_data, monitored_label, unmonitored_data = split_mon_unmon(paths, labels)
        monitored_data = np.array(monitored_data)
        monitored_label = np.array(monitored_label)
        unmonitored_data = np.array(unmonitored_data)

        # The return value is ignored, so this presumably shuffles in place
        # (NOTE(review): confirm in helpers.shuffle_data).
        helpers.shuffle_data(unmonitored_data)
        split_at = int((1 - TEST_SIZE) * len(unmonitored_data))
        unmon_train, unmon_test = unmonitored_data[:split_at], unmonitored_data[split_at:]

        # n_splits=1, so the first (and only) split is all that is needed.
        sss = StratifiedShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=123)
        train_index, test_index = next(sss.split(monitored_data, monitored_label))

        # np.append without an axis flattens both operands into a 1-D array.
        X_train = np.append(monitored_data[train_index], unmon_train)
        X_test = np.append(monitored_data[test_index], unmon_test)
        # Unmonitored test traces are labelled -1. The count must follow
        # unmon_test (not unmon_train) so that y_test lines up with X_test.
        y_test = np.append(monitored_label[test_index], [-1] * len(unmon_test))

        store_data(X_test, 'X_test')
        store_data(y_test, 'y_test')

        stdout.write("Training on data...\n")
        # Only the training inputs are handed to run_model; no training
        # labels are passed here.
        run_model(X_train, in_memory=False)
        stdout.write("Finished running model.")

    except KeyboardInterrupt:
        stdout.write("Interrupted, this might take a while...\n")
        exit(0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号