Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 67 additions & 27 deletions mdpath/src/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,60 @@
from mdpath.src.structure import StructureCalculations


def _max_weight_shortest_path(graph: nx.Graph, source: int, target: int) -> Tuple:
    """Find the heaviest path among all fewest-hop paths from source to target.

    Hop count is the primary criterion; the summed ``weight`` edge attribute
    (treated as 0 when an edge has none) only breaks ties between paths of
    equal length. Kept at module level so worker processes can call it
    without pickling a GraphBuilder instance.

    Args:
        graph: Graph exposing ``neighbors(node)`` and ``graph[u][v]`` edge data.
        source: Node the search starts from.
        target: Node the search stops at.

    Returns:
        Tuple of (list of nodes from source to target, total edge weight).

    Raises:
        nx.NetworkXNoPath: If the search runs out of nodes before reaching target.
    """
    unreached = (float("inf"), -float("inf"))
    # node -> (hops, accumulated weight) of the best route discovered so far.
    best_seen = {source: (0, 0)}
    # Entries are (hops, -weight, node, route): the heap yields fewer hops
    # first and, for equal hops, the larger accumulated weight first.
    frontier = [(0, 0, source, [source])]

    while frontier:
        hops, negated_weight, node, route = heapq.heappop(frontier)
        weight_so_far = -negated_weight

        if node == target:
            return route, weight_so_far

        # Skip entries that were superseded by a better one after being pushed.
        known_hops, known_weight = best_seen.get(node, unreached)
        superseded = hops > known_hops or (
            hops == known_hops and weight_so_far < known_weight
        )
        if superseded:
            continue

        next_hops = hops + 1
        for neighbor in graph.neighbors(node):
            candidate_weight = weight_so_far + graph[node][neighbor].get("weight", 0)

            seen_hops, seen_weight = best_seen.get(neighbor, unreached)
            improves = next_hops < seen_hops or (
                next_hops == seen_hops and candidate_weight > seen_weight
            )
            if not improves:
                continue

            best_seen[neighbor] = (next_hops, candidate_weight)
            heapq.heappush(
                frontier,
                (next_hops, -candidate_weight, neighbor, route + [neighbor]),
            )

    raise nx.NetworkXNoPath(f"No path between {source} and {target}.")


# Residue graph for the current process. Stays None until _init_path_worker
# assigns it; _worker_calc_path then reads it for each residue pair.
_WORKER_GRAPH: "nx.Graph | None" = None


def _init_path_worker(graph: nx.Graph) -> None:
    """Store ``graph`` in the module-level ``_WORKER_GRAPH`` slot.

    Meant to be used as a pool initializer, so the graph is handed to each
    worker once and later tasks only need to carry a residue pair.

    Args:
        graph: Residue graph that subsequent tasks in this process will search.
    """
    global _WORKER_GRAPH
    _WORKER_GRAPH = graph


def _worker_calc_path(residue_pair: tuple):
    """Compute the max-weight shortest path for one residue pair.

    Runs against the graph stored in ``_WORKER_GRAPH`` (set by
    ``_init_path_worker``).

    Args:
        residue_pair: Two-item sequence unpacked as (source, target).

    Returns:
        Whatever ``_max_weight_shortest_path`` returns for the pair, or None
        when it raises ``nx.NetworkXNoPath``.
    """
    source, target = residue_pair
    try:
        outcome = _max_weight_shortest_path(_WORKER_GRAPH, source, target)
    except nx.NetworkXNoPath:
        # An unreachable pair is reported as None rather than raised.
        outcome = None
    return outcome


class GraphBuilder:
"""Build and analyze residue interaction graphs based on residue distances and mutual information between residue pais.

Expand Down Expand Up @@ -129,31 +183,7 @@ def max_weight_shortest_path(self, source: int, target: int) -> Tuple:

total_weight (float): Total weight of the shortest path.
"""
best = {source: (0, 0)}
heap = [(0, 0, source, [source])]

while heap:
dist, neg_w, u, path = heapq.heappop(heap)
acc_w = -neg_w

if u == target:
return path, acc_w

prev_dist, prev_w = best.get(u, (float("inf"), -float("inf")))
if dist > prev_dist or (dist == prev_dist and acc_w < prev_w):
continue

for v in self.graph.neighbors(u):
edge_w = self.graph[u][v].get("weight", 0)
new_dist = dist + 1
new_w = acc_w + edge_w

prev_v = best.get(v, (float("inf"), -float("inf")))
if new_dist < prev_v[0] or (new_dist == prev_v[0] and new_w > prev_v[1]):
best[v] = (new_dist, new_w)
heapq.heappush(heap, (new_dist, -new_w, v, path + [v]))

raise nx.NetworkXNoPath(f"No path between {source} and {target}.")
return _max_weight_shortest_path(self.graph, source, target)

def collect_path_total_weights(self, df_distant_residues: pd.DataFrame) -> list:
"""Wrapper function to collect the shortest path and total weight between distant residues.
Expand Down Expand Up @@ -211,13 +241,23 @@ def collect_path_total_weights_parallel(
for _, row in df_distant_residues.iterrows()
]
path_total_weights = []
with Pool(processes=num_parallel_processes) as pool:
# Change
# Pickle the graph once per worker via initializer/initargs instead of
# once per task (which is what happens when a bound method is dispatched).
chunksize = max(1, len(residue_pairs) // (num_parallel_processes * 4)) if residue_pairs else 1
with Pool(
processes=num_parallel_processes,
initializer=_init_path_worker,
initargs=(self.graph,),
) as pool:
with tqdm(
total=len(residue_pairs),
ascii=True,
desc="\033[1mCalculating path total weights\033[0m",
) as pbar:
results = pool.imap_unordered(self.calc_path_weight, residue_pairs)
results = pool.imap_unordered(
_worker_calc_path, residue_pairs, chunksize=chunksize
)
for result in results:
if result is not None:
path_total_weights.append(result)
Expand Down
32 changes: 32 additions & 0 deletions mdpath/tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,38 @@ def test_collect_path_total_weights():
assert result == case["expected_result"]


def test_collect_path_total_weights_parallel():
    """Check that the parallel path collection matches the serial one."""
    weighted_edges = [
        (1, 2, 1.0),
        (2, 3, 2.0),
        (1, 3, 4.0),
        (3, 4, 1.0),
        (2, 4, 3.0),
    ]

    def as_comparable(results):
        # Sort so the comparison does not depend on the order results arrive
        # in; round the weights to absorb float noise.
        return sorted((tuple(path), round(weight, 9)) for path, weight in results)

    with patch("mdpath.src.graph.StructureCalculations"), patch(
        "mdpath.src.graph.PDB.PDBParser"
    ):
        residue_graph = nx.Graph()
        for node_a, node_b, edge_weight in weighted_edges:
            residue_graph.add_edge(node_a, node_b, weight=edge_weight)

        builder = GraphBuilder(
            pdb="", last_residue=0, mi_diff_df=pd.DataFrame(), graphdist=5
        )
        builder.graph = residue_graph

        # Residue 99 is not a node of the graph built above.
        pairs_df = pd.DataFrame({"Residue1": [1, 2, 1], "Residue2": [4, 4, 99]})

        from_serial = builder.collect_path_total_weights(pairs_df)
        from_parallel = builder.collect_path_total_weights_parallel(
            pairs_df, num_parallel_processes=2
        )

        assert as_comparable(from_parallel) == as_comparable(from_serial)


def test_graph_skeleton():
"""Test the graph_skeleton method using actual data from mi_diff_df.csv and first_frame.pdb."""

Expand Down
Loading