def main():
    """Generate one standalone HTML visualization of the beam search tree per example."""
    beam_data = np.load(ARGS.data)

    # Optionally load vocabulary data so token ids can be rendered as words
    vocab = None
    if ARGS.vocab:
        with open(ARGS.vocab) as file:
            vocab = file.readlines()
        vocab = [_.strip() for _ in vocab]
        # Append the special tokens that are not part of the vocabulary file
        vocab += ["UNK", "SEQUENCE_START", "SEQUENCE_END"]

    if not os.path.exists(ARGS.output_dir):
        os.makedirs(ARGS.output_dir)

    # Copy required static assets (CSS/JS) used by the generated HTML pages
    shutil.copy2("./bin/tools/beam_search_viz/tree.css", ARGS.output_dir)
    shutil.copy2("./bin/tools/beam_search_viz/tree.js", ARGS.output_dir)

    # Build one search tree per example and write it out as its own HTML page
    for idx in range(len(beam_data["predicted_ids"])):
        predicted_ids = beam_data["predicted_ids"][idx]
        parent_ids = beam_data["beam_parent_ids"][idx]
        scores = beam_data["scores"][idx]

        graph = create_graph(
            predicted_ids=predicted_ids,
            parent_ids=parent_ids,
            scores=scores,
            vocab=vocab)

        # Serialize the tree rooted at node (0, 0) to JSON and embed it in the template
        json_str = json.dumps(
            json_graph.tree_data(graph, (0, 0)),
            ensure_ascii=False)

        html_str = HTML_TEMPLATE.substitute(DATA=json_str)
        output_path = os.path.join(ARGS.output_dir, "{:06d}.html".format(idx))
        with open(output_path, "w") as file:
            file.write(html_str)
        print(output_path)
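
# Assumed entry point: ARGS, HTML_TEMPLATE, and create_graph are expected to be
# defined earlier in the full script (e.g. by an argparse parser whose dests are
# data, vocab, and output_dir). A minimal sketch, assuming this excerpt is the
# end of the file, of how the script would then be run directly:
#
#   python generate_beam_viz.py --data beams.npz --vocab vocab.txt --output_dir viz/
#
# (file and flag names above are hypothetical examples, not taken from this excerpt)
if __name__ == "__main__":
    main()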