Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions mdpath/src/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from scipy.spatial import cKDTree
from Bio import PDB
from typing import Tuple, List
from multiprocessing import Pool
from tqdm import tqdm
from mdpath.src.structure import StructureCalculations


Expand Down Expand Up @@ -174,3 +176,50 @@ def collect_path_total_weights(self, df_distant_residues: pd.DataFrame) -> list:
except nx.NetworkXNoPath:
continue
return path_total_weights

def calc_path_weight(self, residue_pair: tuple) -> tuple:
"""Calculates the shortest path and total weight for a single residue pair.

Args:
residue_pair (tuple): Tuple of (Residue1, Residue2).

Returns:
tuple | None: (shortest_path, total_weight) or None if no path exists.
"""
res1, res2 = residue_pair
try:
shortest_path, total_weight = self.max_weight_shortest_path(res1, res2)
return (shortest_path, total_weight)
except nx.NetworkXNoPath:
return None

def collect_path_total_weights_parallel(
self, df_distant_residues: pd.DataFrame, num_parallel_processes: int
) -> list:
"""Parallel wrapper to collect the shortest path and total weight between distant residues.

Args:
df_distant_residues (pd.DataFrame): DataFrame with distant residues (columns: Residue1, Residue2).

num_parallel_processes (int): Number of parallel processes.

Returns:
path_total_weights (list): List of tuples with the shortest path and total weight between distant residues.
"""
residue_pairs = [
(row["Residue1"], row["Residue2"])
for _, row in df_distant_residues.iterrows()
]
path_total_weights = []
with Pool(processes=num_parallel_processes) as pool:
with tqdm(
total=len(residue_pairs),
ascii=True,
desc="\033[1mCalculating path total weights\033[0m",
) as pbar:
results = pool.imap_unordered(self.calc_path_weight, residue_pairs)
for result in results:
if result is not None:
path_total_weights.append(result)
pbar.update(1)
return path_total_weights
Loading