def process_data(coords, nbr_idx, elements, *, max_nbrs=12, atom_dict=None):
    """Trim padding, map element keys to atom types, and build fixed-width neighbor arrays.

    The number of real atoms is taken to be ``len(nbr_idx)``. Rows of
    ``coords`` and entries of ``elements`` beyond that count are discarded
    (the original code describes them as zero padding at the end).

    Args:
        coords: 2-D array-like of per-atom coordinates. Only the first
            ``len(nbr_idx)`` rows are kept.
        nbr_idx: Sequence with one entry per atom. Each entry is a sequence
            of neighbor atom indices, at most ``max_nbrs`` long. It is NOT
            modified.
        elements: Indexable of per-atom keys that are looked up in
            ``atom_dict`` (the original comment calls them atomic numbers;
            confirm against callers).
        max_nbrs: Width of the neighbor arrays. Shorter neighbor lists are
            padded. Defaults to 12, the previously hard-coded width.
        atom_dict: Mapping from ``elements`` entries to integer atom types.
            Defaults to the ``atom_dictionary`` name resolved from the
            enclosing module.

    Returns:
        Tuple ``(coords, nbr_idx, elements, nbr_atoms)``:

        - coords: float32 array holding the first ``num_atoms`` rows.
        - nbr_idx: int32 array of shape ``(num_atoms, max_nbrs)``. Padded
          slots hold 0.
        - elements: int32 array of atom types, shape ``(num_atoms,)``.
        - nbr_atoms: int32 array of shape ``(num_atoms, max_nbrs)`` holding
          each neighbor's atom type. Padded (nonexistent) neighbors get 0.

    Raises:
        ValueError: If an atom lists more than ``max_nbrs`` neighbors.
        KeyError: If an ``elements`` entry is missing from a dict ``atom_dict``.
        IndexError: If a neighbor index exceeds ``num_atoms`` (raised by
            ``np.take``).

    NOTE(review): a neighbor index equal to ``num_atoms`` is indistinguishable
    from padding and is silently treated as a nonexistent neighbor.
    """
    if atom_dict is None:
        # Not defined in this function; resolved from the enclosing module.
        atom_dict = atom_dictionary
    num_atoms = len(nbr_idx)

    # Drop the trailing padding rows/entries beyond the real atoms.
    coords = np.asarray(coords)[:num_atoms, :]
    atom_types = np.array(
        [atom_dict[elements[i]] for i in range(num_atoms)], dtype=np.int32
    )

    # Start every neighbor slot at the sentinel index num_atoms, then copy the
    # real neighbors in. Building a fresh array (instead of extending the
    # caller's lists in place) leaves nbr_idx untouched.
    padded = np.full((num_atoms, max_nbrs), num_atoms, dtype=np.int32)
    for i, nbrs in enumerate(nbr_idx):
        if len(nbrs) > max_nbrs:
            raise ValueError(
                f"atom {i} has {len(nbrs)} neighbors; "
                f"at most {max_nbrs} are supported"
            )
        padded[i, :len(nbrs)] = nbrs

    # Appending type 0 at position num_atoms makes the sentinel slots look up
    # type 0, which marks a nonexistent atom.
    nbr_atoms = np.take(np.append(atom_types, 0), padded)
    # In the returned neighbor indices, padded slots are reported as 0.
    padded[padded >= num_atoms] = 0

    return (
        coords.astype(np.float32),
        padded,
        atom_types,
        nbr_atoms.astype(np.int32),
    )
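# 评论列表 ("Comment list") -- stray web-page text, kept as a comment so the file parses
# 文章目录 ("Table of contents") -- stray web-page text, kept as a comment so the file parses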