# Methods of the class defined in mdpath/src/graph.py (its header lies outside
# this hunk); indent them one level when placing them in the class body.
# They rely on the module-level names Pool (multiprocessing), tqdm, nx and pd.


def calc_path_weight(self, residue_pair: tuple) -> tuple:
    """Calculates the shortest path and total weight for a single residue pair.

    Takes the pair as one argument so it can be handed directly to
    ``Pool.imap_unordered`` as the per-task function.

    Args:
        residue_pair (tuple): Tuple of (Residue1, Residue2).

    Returns:
        tuple | None: (shortest_path, total_weight) as produced by
        ``max_weight_shortest_path``, or None if no path exists.
    """
    res1, res2 = residue_pair
    try:
        shortest_path, total_weight = self.max_weight_shortest_path(res1, res2)
    except nx.NetworkXNoPath:
        # Unconnected pairs are reported as None so the caller can drop them.
        return None
    return (shortest_path, total_weight)


def collect_path_total_weights_parallel(
    self, df_distant_residues: pd.DataFrame, num_parallel_processes: int
) -> list:
    """Parallel wrapper to collect the shortest path and total weight between distant residues.

    Every (Residue1, Residue2) pair is passed to ``calc_path_weight`` in a
    worker process; pairs without a connecting path are skipped. Results are
    gathered with ``imap_unordered``, so the returned list is not guaranteed
    to follow the row order of the input.

    Args:
        df_distant_residues (pd.DataFrame): DataFrame with distant residues (columns: Residue1, Residue2).
        num_parallel_processes (int): Number of parallel processes.

    Returns:
        path_total_weights (list): List of tuples with the shortest path and total weight between distant residues.
    """
    import os  # only needed for the chunk-size heuristic below

    # Read the two columns directly: iterrows() builds a Series for every row
    # and upcasts the values to a dtype common to all columns.
    residue_pairs = list(
        zip(df_distant_residues["Residue1"], df_distant_residues["Residue2"])
    )
    if not residue_pairs:
        # Nothing to compute, so do not start any worker processes.
        return []

    # The task function is a bound method, so `self` is serialized together
    # with each task sent to a worker. Batching pairs into chunks cuts that to
    # once per chunk; the divisor keeps chunks small enough that the progress
    # bar still advances regularly and the load stays balanced.
    workers = num_parallel_processes or os.cpu_count() or 1
    chunksize = max(1, len(residue_pairs) // (workers * 16))

    path_total_weights = []
    with Pool(processes=num_parallel_processes) as pool, tqdm(
        total=len(residue_pairs),
        ascii=True,
        desc="\033[1mCalculating path total weights\033[0m",
    ) as pbar:
        results = pool.imap_unordered(
            self.calc_path_weight, residue_pairs, chunksize=chunksize
        )
        for result in results:
            # Count every finished pair, including those without a path.
            pbar.update(1)
            if result is not None:
                path_total_weights.append(result)
    return path_total_weights