def extract_patches(self, all_seqdata, with_targets=True, inplace=False):
    """Flatten per-sequence contact predictions into one feature matrix.

    For every sequence record, each 2-D map in ``seqdata.cpreds`` is reduced
    to its strict lower triangle (diagonal excluded, indexed for an L x L
    grid where ``L = len(seqdata)``). NaN cells are replaced by
    ``Params.DISTANCE_NAN``, and a missing (``None``) map becomes a vector
    filled entirely with ``Params.DISTANCE_NAN``. The flattened maps form the
    columns of a per-sequence block, and the blocks of all sequences are
    stacked row-wise.

    Args:
        all_seqdata: Iterable of sequence records. Each must support
            ``len()`` and expose ``cpreds`` (a mutable list whose entries
            are 2-D arrays or ``None``). When ``with_targets`` is true, each
            record must also expose ``distances``.
        with_targets: Also build the target vector from
            ``distances_to_contacts(seqdata.distances)``.
        inplace: When false, each record is deep-copied first so the
            caller's data is left untouched. When true, the entries of
            ``seqdata.cpreds`` are overwritten with the flattened 1-D arrays.

    Returns:
        ``(X, Y)`` if ``with_targets`` else ``X``. ``X`` has one row per
        lower-triangle cell (over all sequences) and one column per entry of
        ``cpreds``. ``Y`` is built from the same lower-triangle indices, so
        its entries line up with the rows of ``X``.

    Raises:
        ValueError: If a non-``None`` prediction map is not 2-D, or if
            ``all_seqdata`` is empty (raised by ``np.concatenate``).
    """
    print("\tConverting dataset to Numpy array...")
    features, targets = [], []
    for seqdata in all_seqdata:
        if not inplace:
            # seqdata.cpreds is overwritten below; protect the caller's record.
            seqdata = copy.deepcopy(seqdata)
        # Strict lower triangle (k=-1 excludes the diagonal) of an L x L grid.
        tril_indices = np.tril_indices(len(seqdata), -1)
        n_cells = tril_indices[0].size
        cpreds = seqdata.cpreds

        # Validate every map before touching the list, so a bad entry never
        # leaves seqdata.cpreds half-converted when inplace=True.
        # The None test must come first: None entries are legal.
        for cpred in cpreds:
            if cpred is not None and np.ndim(cpred) != 2:
                raise ValueError(
                    "contact prediction maps must be 2-D, got shape %r"
                    % (np.shape(cpred),)
                )

        # Create contact maps from file data
        for i, cpred in enumerate(cpreds):
            if cpred is None:
                # Missing predictor: mark every lower-triangle cell as "no value".
                flat = np.full(n_cells, Params.DISTANCE_NAN, dtype=Params.FLOAT_DTYPE)
            else:
                # Fancy indexing returns a copy, so the NaN replacement below
                # does not modify the caller's 2-D map. NaNs are replaced
                # before the dtype cast so the cast never sees a NaN.
                flat = np.asarray(cpred)[tril_indices]
                flat[np.isnan(flat)] = Params.DISTANCE_NAN
                flat = np.asarray(flat, dtype=Params.FLOAT_DTYPE)
            cpreds[i] = flat

        # (n_maps, n_cells) -> (n_cells, n_maps): one row per cell, one column per map.
        features.append(np.asarray(cpreds, dtype=Params.FLOAT_DTYPE).T)

        if with_targets:
            # NOTE(review): assumes distances_to_contacts returns an (L, L)
            # map compatible with tril_indices -- confirm against its definition.
            contacts = distances_to_contacts(seqdata.distances)
            targets.append(contacts[tril_indices])

    X = np.concatenate(features, axis=0)
    if not with_targets:
        return X
    return X, np.concatenate(targets, axis=0)
评论列表
文章目录