visualize.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:pathnet-pytorch 作者: kimhc6028 项目源码 文件源码
def get_fig(self, genes, e_color):
        """Overlay the paths in ``genes`` on ``self.graph`` and draw the graph.

        For every gene, each module index in one layer is connected to each
        module index in the next layer. An edge that is not yet in the graph
        is created with ``self.init_edge_weight``; an edge that already exists
        has both endpoints enlarged via ``self.node_upsize`` and its weight
        increased by ``self.edge_weight_add``. In both cases the edge colour is
        set to ``e_color``. Afterwards the edges of ``self.fixed_path`` are
        (re)drawn with ``self.fixed_color`` / ``self.fixed_weight`` and the
        whole graph is rendered with networkx.

        Args:
            genes: Iterable of genes. Each gene is a sliceable sequence of
                layers, and each layer is an iterable of module indices that
                are looked up in ``self.node_ids`` as ``(layer_num, index)``.
            e_color: Colour stored on every edge touched by ``genes``.

        Returns:
            None. The method mutates ``self.graph`` and issues the networkx
            draw calls; it does not return a figure object.
        """
        graph = self.graph

        # Consecutive (layer, next_layer) pairs of the fixed path.
        fixed_pair = list(zip(self.fixed_path, self.fixed_path[1:]))

        for gene in genes:
            gene_pair = zip(gene, gene[1:])

            # Zipping with fixed_pair only limits how many layer transitions
            # are processed (never more than the fixed path has); the fixed
            # entries themselves are not used here.
            # NOTE(review): presumably genes and fixed_path always have the
            # same depth, making this truncation a no-op -- confirm.
            for layer_num, ((first_layer, second_layer), _) in enumerate(
                    zip(gene_pair, fixed_pair)):
                for first_num in first_layer:
                    for second_num in second_layer:
                        first_node = self.node_ids[(layer_num, first_num)]
                        second_node = self.node_ids[(layer_num + 1, second_num)]
                        if graph.has_edge(first_node, second_node):
                            # Edge reused by another gene: emphasise it.
                            self.node_upsize(first_node)
                            self.node_upsize(second_node)
                            weight = graph.get_edge_data(first_node, second_node)['weight']
                            weight += self.edge_weight_add
                        else:
                            weight = self.init_edge_weight
                        graph.add_edge(first_node, second_node,
                                       color=e_color, weight=weight)

        # Fixed-path edges are written last so their colour/weight win over
        # anything a gene assigned to the same edge above.
        for fixed_first, fixed_second in fixed_pair:
            for f_1 in fixed_first:
                for f_2 in fixed_second:
                    if f_1 is not None and f_2 is not None:
                        graph.add_edge(f_1, f_2, color=self.fixed_color,
                                       weight=self.fixed_weight)

        # Materialise nodes/edges once so the per-node and per-edge style
        # lists below are guaranteed to line up with what is drawn.
        nodes = list(graph.nodes(data=True))
        node_ids = [node for node, _ in nodes]
        node_size = [data['size'] for _, data in nodes]

        edges = list(graph.edges())
        edge_color = [graph[u][v]['color'] for u, v in edges]
        weights = [graph[u][v]['weight'] for u, v in edges]

        pos = nx.get_node_attributes(graph, 'Position')

        # networkx expects ``nodelist`` / ``edgelist``; the former ``nodes=``
        # and ``edges=`` keywords are not part of the API and raise TypeError
        # on releases that no longer accept arbitrary **kwds.
        nx.draw_networkx_nodes(graph, pos=pos, nodelist=node_ids,
                               node_color='g', node_size=node_size,
                               node_shape='s')
        nx.draw_networkx_edges(graph, pos=pos, edgelist=edges,
                               edge_color=edge_color, width=weights)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号