seq2seq.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:website-fingerprinting 作者: AxelGoetz 项目源码 文件源码
def get_vector_representations(sess, model, data, save_dir,
                       batch_size=100,
                       max_batches=None,
                       batches_in_epoch=1000,
                       max_time_diff=float("inf"),
                       extension=".cell"):
    """
    Given a trained model, writes a vector representation of each trace to disk.

    Every batch is fed through `model.encoder_final_state`; each row of the
    result is handed to `helpers.write_to_file` together with `save_dir`, a file
    name derived from the trace's path and the `.cellf` extension.

    On `KeyboardInterrupt` the function writes 'Interrupted' to stdout and
    terminates the process with exit status 0.

    @param sess is a tensorflow session
    @param model is the seq2seq model (must provide `next_batch` and `encoder_final_state`)
    @param data is the data (in batch-major form and not padded, or a list of files);
        only `len(data)` is used directly here, the rest is left to `helpers.get_batches`
    @param save_dir is forwarded to `helpers.write_to_file` for every trace
    @param batch_size is the batch size given to `helpers.get_batches`
    @param max_batches is the maximum number of batches to process; when `None` or
        larger than `len(data) // batch_size`, it is replaced by `len(data) // batch_size - 1`
    @param batches_in_epoch is not used by this function (kept so existing callers keep working)
    @param max_time_diff is forwarded unchanged to `model.next_batch`
    @param extension is forwarded to `helpers.extract_filename_from_path`
        (presumably the extension of the input trace files - confirm against the helper)
    """
    batches = helpers.get_batches(data, batch_size=batch_size)

    batches_in_data = len(data) // batch_size
    # NOTE(review): the `- 1` means the default run processes one batch fewer than
    # `len(data) // batch_size`, whereas an explicit `max_batches == batches_in_data`
    # processes all of them. Left as-is because `helpers.get_batches` is not visible
    # here - confirm whether the last batch is skipped on purpose.
    if max_batches is None or batches_in_data < max_batches:
        max_batches = batches_in_data - 1

    try:
        for batch in range(max_batches):
            print("Batch {}/{}".format(batch, max_batches))
            # NOTE(review): meaning of the positional `False` is defined by
            # `model.next_batch`, which is outside this block.
            fd, paths, _ = model.next_batch(batches, False, max_time_diff)
            final_state = sess.run(model.encoder_final_state, fd)

            # An LSTM state comes back as a (c, h) tuple, so we concatenate both
            # parts along axis 1 to get a single feature row per trace
            if isinstance(final_state, LSTMStateTuple):
                final_state = np.concatenate((final_state.c, final_state.h), axis=1)

            # Resolve all file names before writing anything for this batch
            file_names = [helpers.extract_filename_from_path(path, extension) for path in paths]

            for file_name, features in zip(file_names, final_state):
                helpers.write_to_file(features, save_dir, file_name, new_extension=".cellf")

    except KeyboardInterrupt:
        stdout.write('Interrupted')
        # `exit` is injected by the `site` module for interactive use and may be
        # missing (e.g. `python -S`); raising SystemExit directly always works.
        raise SystemExit(0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号