def _create_viz_dag(comp_dag, colors='state', cmap=None):
    """Build a display-ready copy of a computation DAG for graph rendering.

    Every node of ``comp_dag`` is copied into a new ``nx.DiGraph`` under a
    short synthetic name (``n0``, ``n1``, ... in iteration order). Each copied
    node gets these attributes:

    - ``label``: the original node name.
    - ``style``: ``'filled'``.
    - ``fillcolor``: chosen according to ``colors``.
    - ``_group``: the node's ``_AN_GROUP`` attribute, or ``None``.

    Every edge is copied as well. An edge's ``_group`` is the group shared by
    both endpoints, or ``None`` when their groups differ.

    :param comp_dag: Source graph whose node data dicts may carry the
        ``_AN_STATE``, ``_AN_TIMING`` and ``_AN_GROUP`` attributes.
    :param colors: ``'state'`` or ``'timing'`` (case-insensitive).

        - ``'state'`` colors each node with ``cmap[state]``, where ``state``
          is the node's ``_AN_STATE`` value (``None`` if absent).
        - ``'timing'`` colors each node by the ``duration`` of its
          ``_AN_TIMING`` value, normalized onto ``[0, 1]`` across all timed
          nodes. Nodes with no duration are white (``'#FFFFFF'``). When all
          durations are equal, every timed node gets the color at 0.
    :param cmap: For ``'state'``, an object indexed by state; defaults to
        ``_state_colors``. For ``'timing'``, a callable mapping a float in
        ``[0, 1]`` to an RGB(A) tuple, such as a matplotlib colormap; defaults
        to a green-yellow-red blend.
    :return: The newly built ``nx.DiGraph``.
    :raises ValueError: If ``colors`` is neither ``'state'`` nor ``'timing'``.
    """
    colors = colors.lower()
    durations = {}
    min_duration = None
    duration_range = None
    if colors == 'state':
        if cmap is None:
            cmap = _state_colors
    elif colors == 'timing':
        if cmap is None:
            cmap = mpl.colors.LinearSegmentedColormap.from_list('blend', ['#15b01a', '#ffff14', '#e50000'])
        # Collect only timings that actually expose a duration; nodes without
        # one are rendered white below instead of breaking the min/max scan.
        for name, data in comp_dag.nodes(data=True):
            duration = getattr(data.get(_AN_TIMING), 'duration', None)
            if duration is not None:
                durations[name] = duration
        if durations:
            min_duration = min(durations.values())
            duration_range = max(durations.values()) - min_duration
    else:
        raise ValueError('{} is not a valid loman colors parameter for visualization'.format(colors))

    viz_dag = nx.DiGraph()
    node_index_map = {}
    # Remember each node's group from this pass so the edge pass does not need
    # ``comp_dag.node[...]``, which was removed in networkx 2.4.
    node_groups = {}
    for i, (name, data) in enumerate(comp_dag.nodes(data=True)):
        short_name = "n{}".format(i)
        group = data.get(_AN_GROUP)
        attr_dict = {
            'label': name,
            'style': 'filled',
            '_group': group
        }
        if colors == 'state':
            attr_dict['fillcolor'] = cmap[data.get(_AN_STATE, None)]
        elif colors == 'timing':
            duration = durations.get(name)
            if duration is None:
                col = '#FFFFFF'
            else:
                # A zero range means all durations are equal; map them to the
                # bottom of the scale rather than dividing by zero.
                if duration_range:
                    norm_duration = (duration - min_duration) / duration_range
                else:
                    norm_duration = 0.0
                col = mpl.colors.rgb2hex(cmap(norm_duration))
            attr_dict['fillcolor'] = col
        viz_dag.add_node(short_name, **attr_dict)
        node_index_map[name] = short_name
        node_groups[name] = group

    for name1, name2 in comp_dag.edges():
        group1 = node_groups[name1]
        group2 = node_groups[name2]
        # An edge belongs to a group only when both endpoints share it.
        attr_dict = {'_group': group1 if group1 == group2 else None}
        viz_dag.add_edge(node_index_map[name1], node_index_map[name2], **attr_dict)
    return viz_dag
评论列表
文章目录