diff --git a/mdpath/src/graph.py b/mdpath/src/graph.py index ba51db5..fffc499 100644 --- a/mdpath/src/graph.py +++ b/mdpath/src/graph.py @@ -23,6 +23,60 @@ from mdpath.src.structure import StructureCalculations +def _max_weight_shortest_path(graph: nx.Graph, source: int, target: int) -> Tuple: + """Shortest path between source and target with the maximum total edge weight + among all shortest (fewest-hop) paths. Module-level so it can run in workers + without pickling a GraphBuilder instance. + """ + best = {source: (0, 0)} + heap = [(0, 0, source, [source])] + + while heap: + dist, neg_w, u, path = heapq.heappop(heap) + acc_w = -neg_w + + if u == target: + return path, acc_w + + prev_dist, prev_w = best.get(u, (float("inf"), -float("inf"))) + if dist > prev_dist or (dist == prev_dist and acc_w < prev_w): + continue + + for v in graph.neighbors(u): + edge_w = graph[u][v].get("weight", 0) + new_dist = dist + 1 + new_w = acc_w + edge_w + + prev_v = best.get(v, (float("inf"), -float("inf"))) + if new_dist < prev_v[0] or (new_dist == prev_v[0] and new_w > prev_v[1]): + best[v] = (new_dist, new_w) + heapq.heappush(heap, (new_dist, -new_w, v, path + [v])) + + raise nx.NetworkXNoPath(f"No path between {source} and {target}.") + + +_WORKER_GRAPH: "nx.Graph | None" = None + + +def _init_path_worker(graph: nx.Graph) -> None: + """Pool initializer: stash the residue graph in a module global so workers + reuse it across tasks instead of receiving it as part of every pickled task. + """ + global _WORKER_GRAPH + _WORKER_GRAPH = graph + + +def _worker_calc_path(residue_pair: tuple): + """Pool task: compute the max-weight shortest path for a single residue pair + using the worker-local graph set up by _init_path_worker. + """ + res1, res2 = residue_pair + try: + return _max_weight_shortest_path(_WORKER_GRAPH, res1, res2) + except nx.NetworkXNoPath: + return None + + class GraphBuilder: """Build and analyze residue interaction graphs based on residue distances and mutual information between residue pais. @@ -129,31 +183,7 @@ def max_weight_shortest_path(self, source: int, target: int) -> Tuple: total_weight (float): Total weight of the shortest path. """ - best = {source: (0, 0)} - heap = [(0, 0, source, [source])] - - while heap: - dist, neg_w, u, path = heapq.heappop(heap) - acc_w = -neg_w - - if u == target: - return path, acc_w - - prev_dist, prev_w = best.get(u, (float("inf"), -float("inf"))) - if dist > prev_dist or (dist == prev_dist and acc_w < prev_w): - continue - - for v in self.graph.neighbors(u): - edge_w = self.graph[u][v].get("weight", 0) - new_dist = dist + 1 - new_w = acc_w + edge_w - - prev_v = best.get(v, (float("inf"), -float("inf"))) - if new_dist < prev_v[0] or (new_dist == prev_v[0] and new_w > prev_v[1]): - best[v] = (new_dist, new_w) - heapq.heappush(heap, (new_dist, -new_w, v, path + [v])) - - raise nx.NetworkXNoPath(f"No path between {source} and {target}.") + return _max_weight_shortest_path(self.graph, source, target) def collect_path_total_weights(self, df_distant_residues: pd.DataFrame) -> list: """Wrapper function to collect the shortest path and total weight between distant residues. @@ -211,13 +241,23 @@ def collect_path_total_weights_parallel( for _, row in df_distant_residues.iterrows() ] path_total_weights = [] - with Pool(processes=num_parallel_processes) as pool: + # Change + # Pickle the graph once per worker via initializer/initargs instead of + # once per task (which is what happens when a bound method is dispatched). + chunksize = max(1, len(residue_pairs) // (num_parallel_processes * 4)) if residue_pairs else 1 + with Pool( + processes=num_parallel_processes, + initializer=_init_path_worker, + initargs=(self.graph,), + ) as pool: with tqdm( total=len(residue_pairs), ascii=True, desc="\033[1mCalculating path total weights\033[0m", ) as pbar: - results = pool.imap_unordered(self.calc_path_weight, residue_pairs) + results = pool.imap_unordered( + _worker_calc_path, residue_pairs, chunksize=chunksize + ) for result in results: if result is not None: path_total_weights.append(result) diff --git a/mdpath/tests/test_graph.py b/mdpath/tests/test_graph.py index 25491ee..1cf394c 100644 --- a/mdpath/tests/test_graph.py +++ b/mdpath/tests/test_graph.py @@ -165,6 +165,38 @@ def test_collect_path_total_weights(): assert result == case["expected_result"] +def test_collect_path_total_weights_parallel(): + with ( + patch("mdpath.src.graph.StructureCalculations"), + patch("mdpath.src.graph.PDB.PDBParser"), + ): + G = nx.Graph() + G.add_edge(1, 2, weight=1.0) + G.add_edge(2, 3, weight=2.0) + G.add_edge(1, 3, weight=4.0) + G.add_edge(3, 4, weight=1.0) + G.add_edge(2, 4, weight=3.0) + + graph_builder = GraphBuilder( + pdb="", last_residue=0, mi_diff_df=pd.DataFrame(), graphdist=5 + ) + graph_builder.graph = G + + df = pd.DataFrame( + {"Residue1": [1, 2, 1], "Residue2": [4, 4, 99]} + ) + + serial_result = graph_builder.collect_path_total_weights(df) + parallel_result = graph_builder.collect_path_total_weights_parallel( + df, num_parallel_processes=2 + ) + + normalize = lambda r: sorted( + (tuple(path), round(w, 9)) for path, w in r + ) + assert normalize(parallel_result) == normalize(serial_result) + + def test_graph_skeleton(): """Test the graph_skeleton method using actual data from mi_diff_df.csv and first_frame.pdb."""