def qm9_edges(g, e_representation='raw_distance'):
    """Build per-edge feature vectors for a QM9 molecular graph.

    Edges whose ``'b_type'`` attribute is ``None`` are removed from ``g``
    in place for the ``'chem_graph'`` and ``'raw_distance'``
    representations; ``'distance_bin'`` keeps them and encodes them by
    distance instead.

    Args:
        g: networkx graph whose edge data carries ``'b_type'`` (an RDKit
            ``BondType`` or ``None``) and, where the representation needs
            it, ``'distance'``.
        e_representation: one of
            ``'chem_graph'``   -- bonds become ``[k]`` with k = 1..4 for
                                  SINGLE/DOUBLE/TRIPLE/AROMATIC;
            ``'distance_bin'`` -- bonds as in ``'chem_graph'``; non-bonded
                                  pairs become ``[bin + 5]`` where ``bin``
                                  (0..9) buckets the distance into width-0.5
                                  bins with thresholds 2.0, 2.5, ..., 6.0;
            ``'raw_distance'`` -- ``[distance, is_single, is_double,
                                  is_triple, is_aromatic]``.

    Returns:
        ``(adjacency, e)``: the adjacency matrix of ``g`` after edge removal
        (an ``np.matrix``, as produced by ``nx.to_numpy_matrix``), and a dict
        mapping ``(n1, n2)`` to that edge's feature list.  Edges whose bond
        type is not one of the four listed types get no entry in ``e``.

    Raises:
        ValueError: if ``e_representation`` is not recognized.
    """
    # Validate up front so a bad value is reported even for edgeless graphs,
    # and raise instead of killing the interpreter with quit().
    if e_representation not in ('chem_graph', 'distance_bin', 'raw_distance'):
        raise ValueError(
            'Incorrect Edge representation transform: %r' % (e_representation,))

    bond_types = (rdkit.Chem.rdchem.BondType.SINGLE,
                  rdkit.Chem.rdchem.BondType.DOUBLE,
                  rdkit.Chem.rdchem.BondType.TRIPLE,
                  rdkit.Chem.rdchem.BondType.AROMATIC)

    remove_edges = []
    e = {}
    # edges(data=True) works on networkx 1.x and 2.x+; edges_iter was removed in 2.0.
    for n1, n2, d in g.edges(data=True):
        b_type = d['b_type']
        if b_type is None:
            if e_representation == 'distance_bin':
                # Non-bonded pair: offset by 5 so bins don't collide with bond codes 1..4.
                e_t = [_distance_bin(d['distance']) + 5]
            else:
                # chem_graph / raw_distance drop non-bonded pairs from the graph.
                remove_edges.append((n1, n2))
                continue
        elif e_representation == 'raw_distance':
            e_t = [d['distance']] + [int(b_type == x) for x in bond_types]
        else:
            # chem_graph and distance_bin encode real bonds identically.
            e_t = [i + 1 for i, x in enumerate(bond_types) if x == b_type]
        if e_t:
            e[(n1, n2)] = e_t

    # Remove after iterating so the edge view isn't mutated mid-loop.
    for edge in remove_edges:
        g.remove_edge(*edge)

    # nx.to_numpy_matrix was removed in networkx 3.0; fall back to the
    # equivalent np.matrix built from to_numpy_array so the return type is unchanged.
    to_numpy_matrix = getattr(nx, 'to_numpy_matrix', None)
    if to_numpy_matrix is not None:
        return to_numpy_matrix(g), e
    import numpy as np
    return np.asmatrix(nx.to_numpy_array(g)), e


def _distance_bin(distance, start=2.0, step=0.5, n_bins=9):
    """Return the first i < n_bins with distance < start + i*step, else n_bins."""
    for i in range(n_bins):
        if distance < start + i * step:
            return i
    return n_bins
评论列表
文章目录