def get_vector_representations(sess, model, data, save_dir,
                               batch_size=100,
                               max_batches=None,
                               batches_in_epoch=1000,
                               max_time_diff=float("inf"),
                               extension=".cell"):
    """
    Encode traces with a trained seq2seq model and write one feature file per trace.

    For each batch, ``model.encoder_final_state`` is evaluated in ``sess`` and
    every row of the result is written to ``save_dir`` (via
    ``helpers.write_to_file``) under the trace's file name with the ``.cellf``
    extension. Progress is printed per batch.

    @param sess is a TensorFlow session used to run ``model.encoder_final_state``
    @param model is the trained seq2seq model; it must provide
        ``next_batch(batches, flag, max_time_diff) -> (feed_dict, paths, _)``
        and ``encoder_final_state``
    @param data is the input data: the traces themselves (batch-major, not
        padded) or a list of files, depending on how ``helpers.get_batches`` /
        ``model.next_batch`` consume it (NOTE(review): confirm which applies)
    @param save_dir is the directory the ``.cellf`` feature files are written to
    @param batch_size is the number of traces per batch
    @param max_batches caps the number of batches processed; ``None`` or a value
        larger than the number of full batches in ``data`` processes every
        full batch
    @param batches_in_epoch is unused; kept for backward compatibility
    @param max_time_diff is forwarded unchanged to ``model.next_batch``
    @param extension is passed to ``helpers.extract_filename_from_path`` when
        deriving the output file name from each input path
    @return None. On KeyboardInterrupt, writes 'Interrupted' and exits the
        process with status 0 (files already written are kept).
    """
    import sys

    batches = helpers.get_batches(data, batch_size=batch_size)
    # Only full batches are counted, so a trailing partial batch
    # (len(data) % batch_size traces) is presumably never encoded --
    # TODO confirm against helpers.get_batches.
    batches_in_data = len(data) // batch_size
    # Clamp to the number of full batches: every full batch gets encoded, and
    # requesting more batches can never result in fewer being processed.
    if max_batches is None or max_batches > batches_in_data:
        max_batches = batches_in_data
    try:
        for batch in range(max_batches):
            print("Batch {}/{}".format(batch, max_batches))
            # NOTE(review): the positional `False` flag is interpreted by
            # model.next_batch; presumably it disables training-only behaviour.
            fd, paths, _ = model.next_batch(batches, False, max_time_diff)
            state = sess.run(model.encoder_final_state, fd)
            # An LSTM encoder yields a (c, h) state tuple; join them along
            # axis 1 so each row holds c followed by h for one trace
            # (assumes c and h are (batch, units) -- TODO confirm).
            if isinstance(state, LSTMStateTuple):
                state = np.concatenate((state.c, state.h), axis=1)
            file_names = [helpers.extract_filename_from_path(path, extension)
                          for path in paths]
            # Row i of the encoder output is assumed to belong to paths[i].
            for file_name, features in zip(file_names, state):
                helpers.write_to_file(features, save_dir, file_name,
                                      new_extension=".cellf")
    except KeyboardInterrupt:
        sys.stdout.write('Interrupted\n')
        # sys.exit instead of the site-provided `exit` builtin, which is
        # intended for the interactive shell and is missing under `python -S`.
        sys.exit(0)
评论列表
文章目录