From 9490ff887bfd79a5962e44b1bea43d6d762ad86f Mon Sep 17 00:00:00 2001 From: "Zitong Zhan (PVE)" Date: Sat, 13 Dec 2025 04:32:24 +0000 Subject: [PATCH 01/28] add normal matvec and memory profiler --- ba_example.py | 36 +++++- bae/optim/optimizer.py | 26 ++-- bae/sparse/py_ops.py | 25 ++-- bae/utils/__init__.py | 1 + bae/utils/linear_operator.py | 163 ++++++++++++++++++++++++ bae/utils/pysolvers.py | 230 +--------------------------------- tests/test_normal_operator.py | 87 +++++++++++++ 7 files changed, 317 insertions(+), 251 deletions(-) create mode 100644 bae/utils/linear_operator.py create mode 100644 tests/test_normal_operator.py diff --git a/ba_example.py b/ba_example.py index 550c25b..dec3758 100644 --- a/ba_example.py +++ b/ba_example.py @@ -1,4 +1,6 @@ from time import perf_counter +from pathlib import Path +from datetime import datetime import torch import pypose as pp @@ -9,13 +11,13 @@ from bae.optim import LM from bae.utils.pysolvers import PCG, CuDSS -# TARGET_DATASET = "ladybug" -# TARGET_PROBLEM = "problem-1723-156502-pre" +TARGET_DATASET = "ladybug" +TARGET_PROBLEM = "problem-1723-156502-pre" # TARGET_PROBLEM = "problem-49-7776-pre" # TARGET_PROBLEM = "problem-1695-155710-pre" # TARGET_PROBLEM = "problem-969-105826-pre" -TARGET_DATASET = "trafalgar" -TARGET_PROBLEM = "problem-257-65132-pre" +# TARGET_DATASET = "trafalgar" +# TARGET_PROBLEM = "problem-257-65132-pre" # TARGET_DATASET = "dubrovnik" # TARGET_PROBLEM = "problem-356-226730-pre" @@ -28,6 +30,21 @@ file_name = f'{TARGET_DATASET}.{TARGET_PROBLEM}' dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET, use_quat=USE_QUATERNIONS) +memory_snapshot_path = None + +if DEVICE.startswith("cuda") and torch.cuda.is_available(): + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + snapshot_dir = Path("memory_traces") + snapshot_dir.mkdir(exist_ok=True) + memory_snapshot_path = snapshot_dir / f"{file_name}_cuda_memory_{timestamp}.pickle" + # Record allocator events so we can inspect GPU memory usage after the run. + torch.cuda.memory._record_memory_history( + enabled="all", + context="all", + stacks="python", + device=torch.device(DEVICE), + clear_history=True, + ) if OPTIMIZE_INTRINSICS: NUM_CAMERA_PARAMS = 10 if USE_QUATERNIONS else 9 @@ -51,7 +68,9 @@ ).to(DEVICE) strategy = pp.optim.strategy.TrustRegion(up=2.0, down=0.5**4) solver = PCG(tol=1e-4, maxiter=250) # or CuDSS() -optimizer = LM(model, strategy=strategy, solver=solver, reject=30) +optimizer = LM(model, matrix_free_normal=True, strategy=strategy, solver=solver, reject=30) + + print('Loss:', least_square_error( model.pose, @@ -68,6 +87,13 @@ loss = optimizer.step(input) print('Iteration', idx, 'loss', loss.item(), 'time', perf_counter() - start) +if memory_snapshot_path: + torch.cuda.synchronize() + torch.cuda.memory._dump_snapshot(str(memory_snapshot_path)) + print(f"CUDA memory snapshot saved to {memory_snapshot_path}") + +# exit() + torch.cuda.synchronize() end = perf_counter() print('Time', end - start) diff --git a/bae/optim/optimizer.py b/bae/optim/optimizer.py index cc83d88..1bc03ba 100644 --- a/bae/optim/optimizer.py +++ b/bae/optim/optimizer.py @@ -7,6 +7,7 @@ from ..autograd.function import TrackingTensor from ..sparse.py_ops import diagonal_op_ from ..sparse.spgemm import CuSparse +from ..utils.linear_operator import NormalMatVec def jacobian(output, params): assert output.optrace[id(output)][0] == 'map', "The last operation in compute graph being indexing transform is not meaningful" @@ -49,7 +50,8 @@ def jacobian(output, params): class LM(ppLM): - def __init__(self, *args, **kwargs): + def __init__(self, *args, matrix_free_normal: bool = False, **kwargs): + self.matrix_free_normal = matrix_free_normal super(LM, self).__init__(*args, **kwargs) self.mm = CuSparse() @@ -59,24 +61,34 @@ def step(self, input, target=None, weight=None): weight = self.weight if weight is None else weight R = list(self.model(input)) R = R[0] - J = jacobian(R, pg['params']) + J_list = jacobian(R, pg['params']) if isinstance(R, TrackingTensor): R = R.tensor() - J = torch.cat([j.to_sparse_coo() for j in J], dim=-1) + J = torch.cat([j.to_sparse_coo() for j in J_list], dim=-1) self.last = self.loss = self.loss if hasattr(self, 'loss') else self.model.loss(input, target) J_T = J.mT self.reject_count = 0 J_T = J_T.to_sparse_csr() J = J.to_sparse_csr() - A = self.mm(J_T, J) + rhs = -J_T @ R.view(-1, 1) - diagonal_op_(A, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) + if self.matrix_free_normal: + diag = NormalMatVec._compute_diag(J).clamp(min=pg['min'], max=pg['max']) + A = NormalMatVec(J, damping=0.0, diag=diag) + diag_scale = 1.0 + else: + A = self.mm(J_T, J) + diagonal_op_(A, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) while self.last <= self.loss: - diagonal_op_(A, op=partial(torch.mul, other=1+pg['damping'])) + if self.matrix_free_normal: + diag_scale *= 1.0 + pg['damping'] + A.set_damping(diag_scale - 1.0) + else: + diagonal_op_(A, op=partial(torch.mul, other=1+pg['damping'])) try: - D = self.solver(A, -J_T @ R.view(-1, 1)) + D = self.solver(A, rhs) D = D[:, None] except Exception as e: print(e, "\nLinear solver failed. Breaking optimization step...") diff --git a/bae/sparse/py_ops.py b/bae/sparse/py_ops.py index 19164f3..464241e 100644 --- a/bae/sparse/py_ops.py +++ b/bae/sparse/py_ops.py @@ -232,16 +232,15 @@ def bsr2bsc(J): sparse_lib.impl('diagonal', diagonal_op_, 'SparseCsrCPU') sparse_lib.impl('diagonal', diagonal_op_, 'SparseCsrCUDA') -if True: - crow_indices = torch.tensor([0, 2, 4]) - col_indices = torch.tensor([0, 1, 0, 1]) - values = torch.tensor([[[0, 1, 2], [6, 7, 8]], - [[3, 4, 5], [9, 10, 11]], - [[12, 13, 14], [18, 19, 20]], - [[15, 16, 17], [21, 22, 23]]]) - bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, dtype=torch.float64) - bsr = bsr.to('cuda') - csr = bsr.to_sparse_coo().to_sparse_csr() - # print(csr) - output = diagonal_op_triton_(csr) - # print(output) \ No newline at end of file +if __name__ == "__main__": + if torch.cuda.is_available(): + crow_indices = torch.tensor([0, 2, 4]) + col_indices = torch.tensor([0, 1, 0, 1]) + values = torch.tensor([[[0, 1, 2], [6, 7, 8]], + [[3, 4, 5], [9, 10, 11]], + [[12, 13, 14], [18, 19, 20]], + [[15, 16, 17], [21, 22, 23]]]) + bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, dtype=torch.float64) + bsr = bsr.to('cuda') + csr = bsr.to_sparse_coo().to_sparse_csr() + output = diagonal_op_triton_(csr) diff --git a/bae/utils/__init__.py b/bae/utils/__init__.py index e69de29..48ea6d3 100644 --- a/bae/utils/__init__.py +++ b/bae/utils/__init__.py @@ -0,0 +1 @@ +from .linear_operator import NormalMatVec diff --git a/bae/utils/linear_operator.py b/bae/utils/linear_operator.py new file mode 100644 index 0000000..070badf --- /dev/null +++ b/bae/utils/linear_operator.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from typing import Optional, Union, Any, Tuple + +import torch +from torch import Tensor + + +class NormalMatVec: + r"""Matrix-free normal-equation linear operator. + + Given a Jacobian J (dense, CSR/COO, or BSR), this operator represents: + + A x = J^T (J x) + damping * diag(J^T J) * x + + where diag(J^T J) is computed as the column-wise sum of squares of J. + """ + + def __init__( + self, + J: Tensor, + damping: Union[float, Tensor] = 0.0, + diag: Optional[Tensor] = None, + ): + if not torch.is_tensor(J): + raise TypeError("J must be a torch.Tensor") + if J.ndim != 2: + raise ValueError("J must be 2-D") + + self.J: Tensor = J + self._Jt: Optional[Tensor] = None + + self.device = J.device + self.dtype = J.dtype + self.layout = J.layout if J.layout != torch.sparse_coo else torch.sparse_csr + + self.shape: Tuple[int, int] = (J.shape[1], J.shape[1]) + self.ndim: int = 2 + + self._diag: Tensor = diag if diag is not None else self._compute_diag(J) + if self._diag.ndim != 1 or self._diag.numel() != J.shape[1]: + raise ValueError("diag must be 1-D with length equal to J.shape[1]") + + self.set_damping(damping) + + def set_damping(self, damping: Union[float, Tensor]) -> None: + if isinstance(damping, Tensor): + if damping.numel() != 1: + raise ValueError("damping tensor must be scalar") + self.damping = damping.to(device=self.device, dtype=self.dtype) + else: + self.damping = float(damping) + + def diagonal(self) -> Tensor: + damp = self._damping_value() + if damp is None: + return self._diag + return self._diag * (1.0 + damp) + + def matvec(self, x: Tensor) -> Tensor: + if x.ndim == 1: + x2d = x.unsqueeze(-1) + elif x.ndim == 2: + x2d = x + else: + raise ValueError("x must be 1-D or 2-D") + + if x2d.device != self.device or x2d.dtype != self.dtype: + x2d = x2d.to(device=self.device, dtype=self.dtype) + + y = self.J @ x2d + Jt = self._get_Jt() + z = Jt @ y + + damp = self._damping_value() + if damp is not None: + z = z + damp * self._diag.unsqueeze(-1) * x2d + + return z.squeeze(-1) if x.ndim == 1 else z + + def __call__(self, x: Tensor) -> Tensor: + return self.matvec(x) + + def __matmul__(self, x: Tensor) -> Tensor: + return self.matvec(x) + + @classmethod + def __torch_function__( # type: ignore[override] + cls, func: Any, types: Any, args: Tuple[Any, ...] = (), kwargs: Optional[dict] = None + ): + if kwargs is None: + kwargs = {} + if func is torch.matmul and len(args) >= 2: + A, B = args[0], args[1] + if isinstance(A, cls): + out = kwargs.get("out", None) + result = A.matvec(B) + if out is not None: + out_tensor = out[0] if isinstance(out, tuple) else out + out_tensor.copy_(result) + return out_tensor + return result + return NotImplemented + + def _get_Jt(self) -> Tensor: + if self._Jt is None: + Jt = self.J.mT + if Jt.layout == torch.sparse_csc: + Jt = Jt.to_sparse_csr() + elif Jt.layout == torch.sparse_bsc: + bs = Jt.values().shape[-2:] + Jt = Jt.to_sparse_bsr(blocksize=bs) + self._Jt = Jt + return self._Jt + + def _damping_value(self) -> Optional[Tensor]: + if isinstance(self.damping, Tensor): + if self.damping.item() == 0.0: + return None + return self.damping + if self.damping == 0.0: + return None + return torch.tensor(self.damping, device=self.device, dtype=self.dtype) + + @staticmethod + def _compute_diag(J: Tensor) -> Tensor: + if J.layout == torch.strided: + return J.square().sum(dim=0) + + if J.layout == torch.sparse_bsr: + values = J.values() + dm, dn = values.shape[-2], values.shape[-1] + col_blocks = J.col_indices() + contrib = values.square().sum(dim=-2) # (nnz_blocks, dn) + offsets = torch.arange(dn, device=contrib.device, dtype=col_blocks.dtype) + cols = (col_blocks[:, None] * dn + offsets[None, :]).reshape(-1).to(torch.int64) + contrib_flat = contrib.reshape(-1) + diag = torch.zeros(J.shape[1], device=contrib.device, dtype=contrib.dtype) + diag.scatter_add_(0, cols, contrib_flat) + return diag + + if J.layout == torch.sparse_csr: + values = J.values() + col = J.col_indices().to(torch.int64) + v2 = values.square() + if v2.ndim > 1: + v2 = v2.reshape(v2.shape[0], -1).sum(dim=-1) + diag = torch.zeros(J.shape[1], device=values.device, dtype=v2.dtype) + diag.scatter_add_(0, col, v2) + return diag + + if J.layout == torch.sparse_coo: + Jc = J.coalesce() + col = Jc.indices()[1].to(torch.int64) + v2 = Jc.values().square() + if v2.ndim > 1: + v2 = v2.reshape(v2.shape[0], -1).sum(dim=-1) + diag = torch.zeros(J.shape[1], device=Jc.device, dtype=v2.dtype) + diag.scatter_add_(0, col, v2) + return diag + + raise NotImplementedError(f"Unsupported J layout: {J.layout}") + diff --git a/bae/utils/pysolvers.py b/bae/utils/pysolvers.py index 772f6ab..f0205f1 100644 --- a/bae/utils/pysolvers.py +++ b/bae/utils/pysolvers.py @@ -13,14 +13,15 @@ def __init__(self, maxiter=None, tol=1e-5): def forward(self, A, b, x=None, M=None) -> torch.Tensor: if b.dim() == 1: b = b[..., None] - l_diag = A.diagonal() + l_diag = A.diagonal().clone() l_diag[l_diag.abs() < 1e-6] = 1e-6 M = spdiags_((1 / l_diag), None, shape=A.shape, layout=None) - if A.layout == torch.sparse_csr: + layout = getattr(A, "layout", torch.strided) + if layout == torch.sparse_csr: # M = M.to_sparse_csr() pass # A = M @ A - elif A.layout == torch.sparse_bsr: + elif layout == torch.sparse_bsr and isinstance(A, torch.Tensor): M = M.to_sparse_bsr(blocksize=A.values().shape[-2:]).to(A.device) # A = M @ A.to_sparse_bsc(blocksize=A.values().shape[-2:]) # b = M @ b @@ -50,226 +51,3 @@ def forward(self, A, b): # print(f"Linear Solver Error: {a_err}, relative error: {r_err}") return torch.from_numpy(x).to(A.device) - -# cuda graph version of the solver -class CG_(torch.nn.Module): - r'''The batched linear solver with conjugate gradient method. - - .. math:: - \mathbf{A}_i \bm{x}_i = \mathbf{b}_i, - - where :math:`\mathbf{A}_i \in \mathbb{C}^{M \times N}` and :math:`\bm{b}_i \in - \mathbb{C}^{M \times 1}` are the :math:`i`-th item of batched linear equations. - - This function is a 1:1 replica of `scipy.sparse.linalg.cg `_. - The solution is consistent with the scipy version up to numerical precision. - Variable names are kept the same as the scipy version for easy reference. - We recommend using only non-batched or batch size 1 input for this solver, as - the batched version was not appeared in the original scipy version. When handling - sparse matrices, the batched computation may introduce additional overhead. - - Examples: - >>> # dense example - >>> import pypose.optim.solver as ppos - >>> A = torch.tensor([[0.1802967, 0.3151198, 0.4548111, 0.3860016, 0.2870615], - [0.3151198, 1.4575327, 1.5533425, 1.0540756, 1.0795838], - [0.4548111, 1.5533425, 2.3674474, 1.1222278, 1.2365348], - [0.3860016, 1.0540756, 1.1222278, 1.3748058, 1.2223261], - [0.2870615, 1.0795838, 1.2365348, 1.2223261, 1.2577004]]) - >>> b = torch.tensor([[ 2.64306851], - [-0.03593633], - [ 0.73612658], - [ 0.51501254], - [-0.26689271]]) - >>> solver = ppos.CG() - >>> x = solver(A, b) - tensor([[246.4098], - [ 22.6997], - [-56.9239], - [-161.7914], - [137.2683]]) - - >>> # sparse csr example - >>> import pypose.optim.solver as ppos - >>> crow_indices = torch.tensor([0, 2, 4]) - >>> col_indices = torch.tensor([0, 1, 0, 1]) - >>> values = torch.tensor([1, 2, 3, 4], dtype=torch.float) - >>> A = torch.sparse_csr_tensor(crow_indices, col_indices, values) - >>> A.to_dense() # visualize - tensor([[1., 2.], - [3., 4.]]) - >>> b = torch.tensor([[1.], [2.]]) - >>> solver = ppos.CG() - >>> x = solver(A, b) - tensor([-4.4052e-05, 5.0003e-01]) - - ''' - def __init__(self, maxiter=None, tol=1e-5): - super().__init__() - self.maxiter, self.tol = maxiter, tol - self.graph_first_iter = None - self.graph_subsequent_iter = None - self.static_A_shape, self.static_b_shape, self.static_M_is_none, self.static_device = \ - None, None, None, None - # Tensors for graph capture/replay - self.static_A, self.static_b, self.static_M = None, None, None - self.static_x, self.static_r, self.static_p, self.static_q, self.static_z = \ - None, None, None, None, None - self.static_rho_prev, self.static_rho_cur = None, None - - def forward(self, A: torch.Tensor, b: Tensor, x: Optional[Tensor]=None, - M: Optional[torch.Tensor]=None) -> Tensor: - ''' - Args: - A (Tensor): the input tensor. It is assumed to be a symmetric - positive-definite matrix. Layout is allowed to be COO, CSR, BSR, or dense. - b (Tensor): the tensor on the right hand side. Layout could be sparse or dense - but is only allowed to be a type that is compatible with the layout of A. - In other words, `A @ b` operation must be supported by the layout of A. - x (Tensor, optional): the initial guess for the solution. Default: ``None``. - M (Tensor, optional): the preconditioner for A. Layout is allowed to be COO, - CSR, BSR, or dense. Default: ``None``. - - Return: - Tensor: the solved tensor. Layout is the same as the layout of b. - ''' - if A.ndim == b.ndim + 1: - b = b.unsqueeze(-1) - else: - assert A.ndim == b.ndim, \ - 'The number of dimensions of A and b must be the same or one more than b' - - if x is None: - x = torch.zeros_like(b) - - bnrm2 = torch.linalg.norm(b, dim=0) - if (bnrm2 == 0).all(): - return b - atol = self.tol * bnrm2 - n = b.shape[-2] - - if self.maxiter is None: - maxiter = n * 10 - else: - maxiter = self.maxiter - - # Determine if CUDA graph can be used and if re-capture is needed - use_cuda_graph = A.is_cuda - - if use_cuda_graph: - re_capture_graph = (self.graph_first_iter is None or \ - self.static_A_shape != A.shape or \ - self.static_b_shape != b.shape or \ - self.static_M_is_none != (M is None) or \ - self.static_device != A.device) - - if re_capture_graph: - # Allocate static tensors and capture new graphs - self.static_A = A.clone() - self.static_b = b.clone() - self.static_x = x.clone() # Initial x - self.static_r = b - A @ x # Initial r - self.static_p = torch.zeros_like(b) # Will be updated - self.static_q = torch.empty_like(b) - self.static_z = torch.empty_like(b) - - # Initialize rho_prev and rho_cur with shape [1, 1] - self.static_rho_prev = torch.zeros(1, 1, device=A.device) - self.static_rho_cur = torch.zeros(1, 1, device=A.device) - - self.static_M_is_none = (M is None) - self.static_device = A.device - self.static_A_shape = A.shape - self.static_b_shape = b.shape - - if M is not None: - self.static_M = M.clone() - else: - self.static_M = None - - # Capture first iteration graph - self.graph_first_iter = torch.cuda.CUDAGraph() - torch.cuda.synchronize() - with torch.cuda.graph(self.graph_first_iter): - # Operations for first iteration - if not self.static_M_is_none: - torch.matmul(self.static_M, self.static_r, out=self.static_z) - else: - self.static_z.copy_(self.static_r) # z = r.clone() - self.static_rho_cur.copy_(torch.matmul(self.static_r.mT, self.static_z)) - self.static_p.copy_(self.static_z) # p = z.clone() - torch.matmul(self.static_A, self.static_p, out=self.static_q) - alpha = self.static_rho_cur / torch.matmul(self.static_p.mT, self.static_q) - self.static_x.add_(alpha * self.static_p) - self.static_r.sub_(alpha * self.static_q) - self.static_rho_prev.copy_(self.static_rho_cur) - - # Capture subsequent iteration graph - self.graph_subsequent_iter = torch.cuda.CUDAGraph() - torch.cuda.synchronize() - with torch.cuda.graph(self.graph_subsequent_iter): - # Operations for subsequent iterations - if not self.static_M_is_none: - torch.matmul(self.static_M, self.static_r, out=self.static_z) - else: - self.static_z.copy_(self.static_r) # z = r.clone() - self.static_rho_cur.copy_(torch.matmul(self.static_r.mT, self.static_z)) - beta = self.static_rho_cur / self.static_rho_prev - self.static_p.mul_(beta).add_(self.static_z) - torch.matmul(self.static_A, self.static_p, out=self.static_q) - alpha = self.static_rho_cur / torch.matmul(self.static_p.mT, self.static_q) - self.static_x.add_(alpha * self.static_p) - self.static_r.sub_(alpha * self.static_q) - self.static_rho_prev.copy_(self.static_rho_cur) - - # Now run the loop using the (newly captured or existing) graphs - self.static_A.copy_(A) - self.static_b.copy_(b) - self.static_x.copy_(x) - self.static_r.copy_(b - A @ x) # Initial r - if M is not None: - self.static_M.copy_(M) - - # First iteration - self.graph_first_iter.replay() - if (torch.linalg.norm(self.static_r, dim=0) < atol).all(): - return self.static_x - - # Subsequent iterations - for iteration in range(1, maxiter): - self.graph_subsequent_iter.replay() - if (torch.linalg.norm(self.static_r, dim=0) < atol).all(): - return self.static_x - return self.static_x - - else: # A is not on CUDA, or other conditions not met for graph, run original Python loop - r = b - A @ x if x.any() else b.clone() - rho_prev, p = None, None - - q = torch.empty_like(b) - if M is not None: - z = torch.empty_like(b) - else: - z = r.clone() - - for iteration in range(maxiter): - if (torch.linalg.norm(r, dim=0) < atol).all(): - return x - - if M is not None: - torch.matmul(M, r, out=z) - rho_cur = torch.matmul(r.mT, z) - if iteration > 0: - beta = rho_cur / rho_prev - p.mul_(beta).add_(z) - else: # First spin - p = z.clone() - - torch.matmul(A, p, out=q) - alpha = rho_cur / torch.matmul(p.mT, q) - x += alpha * p - r -= alpha * q - rho_prev = rho_cur - - return x diff --git a/tests/test_normal_operator.py b/tests/test_normal_operator.py new file mode 100644 index 0000000..b101514 --- /dev/null +++ b/tests/test_normal_operator.py @@ -0,0 +1,87 @@ +import pytest +import torch + +from bae.utils.linear_operator import NormalMatVec +from bae.utils.pysolvers import PCG + + +def test_normal_matvec_dense_matches_explicit(): + torch.manual_seed(0) + m, n = 8, 5 + J = torch.randn(m, n, dtype=torch.float64) + x = torch.randn(n, dtype=torch.float64) + + op = NormalMatVec(J) + y_op = op.matvec(x) + y_ex = J.mT @ (J @ x) + + torch.testing.assert_close(y_op, y_ex, rtol=1e-10, atol=1e-10) + + +def test_normal_matvec_dense_damping_matches_explicit(): + torch.manual_seed(1) + m, n = 7, 4 + J = torch.randn(m, n, dtype=torch.float64) + x = torch.randn(n, dtype=torch.float64) + damping = 0.3 + + diag = J.square().sum(dim=0) + op = NormalMatVec(J, damping=damping) + y_op = op @ x + y_ex = J.mT @ (J @ x) + damping * diag * x + + torch.testing.assert_close(y_op, y_ex, rtol=1e-10, atol=1e-10) + + +def test_normal_matvec_sparse_csr_matches_explicit_and_cached_diag(): + torch.manual_seed(2) + m, n = 10, 6 + dense = torch.randn(m, n, dtype=torch.float64) + mask = (torch.rand_like(dense) < 0.5) + dense = dense * mask + J = dense.to_sparse_csr() + x = torch.randn(n, dtype=torch.float64) + + diag = dense.square().sum(dim=0) + op = NormalMatVec(J, diag=diag) + y_op = op.matvec(x) + y_ex = dense.mT @ (dense @ x) + + torch.testing.assert_close(y_op, y_ex, rtol=1e-10, atol=1e-10) + + +def test_normal_matvec_sparse_bsr_matches_explicit(): + torch.manual_seed(3) + m, n = 6, 4 + dense = torch.randn(m, n, dtype=torch.float64).contiguous() + x = torch.randn(n, dtype=torch.float64) + + try: + J_bsr = dense.to_sparse_bsr(blocksize=(2, 2)) + except Exception: + pytest.skip("BSR conversion not supported on this device/build.") + + op = NormalMatVec(J_bsr) + y_op = op @ x + y_ex = dense.mT @ (dense @ x) + + torch.testing.assert_close(y_op, y_ex, rtol=1e-10, atol=1e-10) + + +def test_pcg_smoke_with_normal_operator(): + torch.manual_seed(4) + m, n = 12, 5 + J = torch.randn(m, n, dtype=torch.float64) + damping = 1e-3 + op = NormalMatVec(J, damping=damping) + + b = torch.randn(n, dtype=torch.float64) + solver = PCG(tol=1e-10, maxiter=200) + x_pcg = solver(op, b) + + diag = J.square().sum(dim=0) + A_dense = J.mT @ J + damping * torch.diag(diag) + x_ex = torch.linalg.solve(A_dense, b) + + torch.testing.assert_close(x_pcg, x_ex, rtol=1e-6, atol=1e-6) + From 9c90acaeee378651fcd16be7236c1b1b15dab66d Mon Sep 17 00:00:00 2001 From: "Zitong Zhan (PVE)" Date: Tue, 23 Dec 2025 18:48:20 +0000 Subject: [PATCH 02/28] print peak cuda allocation --- ba_example.py | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/ba_example.py b/ba_example.py index dec3758..244ec27 100644 --- a/ba_example.py +++ b/ba_example.py @@ -31,6 +31,20 @@ file_name = f'{TARGET_DATASET}.{TARGET_PROBLEM}' dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET, use_quat=USE_QUATERNIONS) memory_snapshot_path = None +cuda_device = torch.device(DEVICE) if DEVICE.startswith("cuda") else None + + +def _format_bytes(num_bytes: int) -> str: + units = ["B", "KiB", "MiB", "GiB", "TiB"] + size = float(num_bytes) + unit = units[0] + for unit in units: + if size < 1024.0 or unit == units[-1]: + break + size /= 1024.0 + if unit == "B": + return f"{int(size)} {unit}" + return f"{size:.2f} {unit}" if DEVICE.startswith("cuda") and torch.cuda.is_available(): timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") @@ -42,7 +56,7 @@ enabled="all", context="all", stacks="python", - device=torch.device(DEVICE), + device=cuda_device, clear_history=True, ) @@ -82,22 +96,36 @@ print("Initial loss", optimizer.model.loss(input, None).item()) +if cuda_device is not None and torch.cuda.is_available(): + torch.cuda.synchronize(cuda_device) + torch.cuda.reset_peak_memory_stats(cuda_device) + start = perf_counter() for idx in range(20): loss = optimizer.step(input) print('Iteration', idx, 'loss', loss.item(), 'time', perf_counter() - start) if memory_snapshot_path: - torch.cuda.synchronize() + torch.cuda.synchronize(cuda_device) torch.cuda.memory._dump_snapshot(str(memory_snapshot_path)) print(f"CUDA memory snapshot saved to {memory_snapshot_path}") # exit() -torch.cuda.synchronize() +if cuda_device is not None and torch.cuda.is_available(): + torch.cuda.synchronize(cuda_device) end = perf_counter() print('Time', end - start) +if cuda_device is not None and torch.cuda.is_available(): + peak_allocated = torch.cuda.max_memory_allocated(cuda_device) + try: + peak_reserved = torch.cuda.max_memory_reserved(cuda_device) + except AttributeError: # older PyTorch + peak_reserved = torch.cuda.max_memory_cached(cuda_device) + print(f"Peak CUDA memory allocated: {_format_bytes(peak_allocated)}") + print(f"Peak CUDA memory reserved: {_format_bytes(peak_reserved)}") + print('Ending loss:', least_square_error( model.pose, model.points_3d, From 6256e799f409ff2931306df7a2aebec572208380 Mon Sep 17 00:00:00 2001 From: "Zitong Zhan (PVE)" Date: Sun, 28 Dec 2025 03:12:06 +0000 Subject: [PATCH 03/28] add warp memory pool report --- ba_example.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/ba_example.py b/ba_example.py index 244ec27..04e2160 100644 --- a/ba_example.py +++ b/ba_example.py @@ -27,11 +27,15 @@ OPTIMIZE_INTRINSICS = True USE_QUATERNIONS = True +REPORT_WARP_MEMPOOL = True file_name = f'{TARGET_DATASET}.{TARGET_PROBLEM}' dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET, use_quat=USE_QUATERNIONS) memory_snapshot_path = None cuda_device = torch.device(DEVICE) if DEVICE.startswith("cuda") else None +warp_device = None +warp_mempool_start_current = None +warp_mempool_start_high = None def _format_bytes(num_bytes: int) -> str: @@ -60,6 +64,17 @@ def _format_bytes(num_bytes: int) -> str: clear_history=True, ) +if REPORT_WARP_MEMPOOL and DEVICE.startswith("cuda"): + try: + if wp.is_cuda_available(): + warp_device = wp.get_device("cuda:0" if DEVICE == "cuda" else DEVICE) + if not wp.is_mempool_enabled(warp_device): + wp.set_mempool_enabled(warp_device, True) + warp_mempool_start_current = wp.get_mempool_used_mem_current(warp_device) + warp_mempool_start_high = wp.get_mempool_used_mem_high(warp_device) + except Exception as e: + print(f"Warning: failed to query Warp mempool stats: {e}") + if OPTIMIZE_INTRINSICS: NUM_CAMERA_PARAMS = 10 if USE_QUATERNIONS else 9 else: @@ -126,6 +141,15 @@ def _format_bytes(num_bytes: int) -> str: print(f"Peak CUDA memory allocated: {_format_bytes(peak_allocated)}") print(f"Peak CUDA memory reserved: {_format_bytes(peak_reserved)}") +if warp_device is not None and warp_mempool_start_current is not None and warp_mempool_start_high is not None: + try: + warp_current = wp.get_mempool_used_mem_current(warp_device) + warp_high = wp.get_mempool_used_mem_high(warp_device) + print(f"Warp CUDA mempool current: {_format_bytes(warp_current)} (Δ {_format_bytes(warp_current - warp_mempool_start_current)})") + print(f"Warp CUDA mempool high-water: {_format_bytes(warp_high)} (Δ {_format_bytes(warp_high - warp_mempool_start_high)})") + except Exception as e: + print(f"Warning: failed to query Warp mempool stats: {e}") + print('Ending loss:', least_square_error( model.pose, model.points_3d, From 3a5ce9bd4fd015de5bbc63e6f6cfb7fdcb8f1ad8 Mon Sep 17 00:00:00 2001 From: "Zitong Zhan (PVE)" Date: Sun, 28 Dec 2025 04:26:48 +0000 Subject: [PATCH 04/28] use `A._get_Jt` when matrix_free_normal --- bae/optim/optimizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bae/optim/optimizer.py b/bae/optim/optimizer.py index 1bc03ba..7a40a70 100644 --- a/bae/optim/optimizer.py +++ b/bae/optim/optimizer.py @@ -67,17 +67,17 @@ def step(self, input, target=None, weight=None): J = torch.cat([j.to_sparse_coo() for j in J_list], dim=-1) self.last = self.loss = self.loss if hasattr(self, 'loss') else self.model.loss(input, target) - J_T = J.mT self.reject_count = 0 - J_T = J_T.to_sparse_csr() J = J.to_sparse_csr() - rhs = -J_T @ R.view(-1, 1) if self.matrix_free_normal: diag = NormalMatVec._compute_diag(J).clamp(min=pg['min'], max=pg['max']) A = NormalMatVec(J, damping=0.0, diag=diag) + rhs = -(A._get_Jt() @ R.view(-1, 1)) diag_scale = 1.0 else: + J_T = J.mT.to_sparse_csr() + rhs = -J_T @ R.view(-1, 1) A = self.mm(J_T, J) diagonal_op_(A, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) From 00641468e7c5c218e87c5ea3ba268c127c0417cc Mon Sep 17 00:00:00 2001 From: "Zitong Zhan (PVE)" Date: Sun, 28 Dec 2025 03:11:03 +0000 Subject: [PATCH 05/28] add back schur by warp's matmul --- .gitignore | 1 + ba_example.py | 63 +++++++++++++++++- bae/optim/optimizer.py | 144 ++++++++++++++++++++++++++++++++++++++++- bae/sparse/py_ops.py | 7 +- 4 files changed, 211 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index a5f0941..4dc0a28 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ tmp_debug/* task.md *.pickle *.png +.warp_cache/* \ No newline at end of file diff --git a/ba_example.py b/ba_example.py index 04e2160..34b0100 100644 --- a/ba_example.py +++ b/ba_example.py @@ -3,13 +3,17 @@ from datetime import datetime import torch import pypose as pp +import warp as wp +from warp import sparse as wpsparse from ba_helpers import Reproj, least_square_error +from bae.optim.optimizer import Schur from datapipes.bal_loader import get_problem, read_bal_data from bae.sparse.py_ops import * from bae.sparse.solve import * from bae.optim import LM from bae.utils.pysolvers import PCG, CuDSS +from bae.sparse.warp_wrappers import format_vec_for_bsr TARGET_DATASET = "ladybug" TARGET_PROBLEM = "problem-1723-156502-pre" @@ -91,13 +95,68 @@ def _format_bytes(num_bytes: int) -> str: "point_indices": trimmed_dataset['point_index_of_observations'] } +class TrustRegion(pp.optim.strategy.TrustRegion): + def update(self, pg, last, loss, J, D, R, *args, **kwargs): + # PyTorch CUDA BSR matvec currently assumes square blocks; allow passing Warp + # BSR matrices via `Jwp` to use Warp's bsrmv for rectangular blocks. + Jwp = kwargs.get("Jwp") + if Jwp is not None: + J = Jwp + JD = None + for i in range(len(D)): + if JD is None: + if Jwp is not None: + Dwp = format_vec_for_bsr(D[i].flatten().contiguous(), J[i].block_shape) + JD = wp.to_torch(wpsparse.bsr_mv(J[i], Dwp)).flatten() + else: + JD = J[i] @ D[i].flatten() + else: + if Jwp is not None: + Dwp = format_vec_for_bsr(D[i].flatten().contiguous(), J[i].block_shape) + JD += wp.to_torch(wpsparse.bsr_mv(J[i], Dwp)).flatten() + else: + JD += J[i] @ D[i].flatten() + JD = JD[..., None] + quality = (last - loss) / -((JD).mT @ (2 * R.view_as(JD) + JD)).squeeze() + pg['radius'] = 1. / pg['damping'] + if quality > pg['high']: + pg['radius'] = pg['up'] * pg['radius'] + pg['down'] = self.down + elif quality > pg['low']: + pg['radius'] = pg['radius'] + pg['down'] = self.down + else: + pg['radius'] = pg['radius'] * pg['down'] + pg['down'] = pg['down'] * pg['factor'] + pg['down'] = max(self.min, min(pg['down'], self.max)) + pg['radius'] = max(self.min, min(pg['radius'], self.max)) + pg['damping'] = 1. / pg['radius'] +class Adaptive(pp.optim.strategy.Adaptive): + def update(self, pg, last, loss, J, D, R, *args, **kwargs): + J = [i.to_sparse_coo() for i in J] + JD = None + for i in range(len(D)): + if JD is None: + JD = J[i] @ D[i] + else: + JD += J[i] @ D[i] + JD = JD[..., None] + quality = (last - loss) / -((JD).mT @ (2 * R.view_as(JD) + JD)).squeeze() + if quality > pg['high']: + pg['damping'] = pg['damping'] * pg['down'] + elif quality > pg['low']: + pg['damping'] = pg['damping'] + else: + pg['damping'] = pg['damping'] * pg['up'] + pg['damping'] = max(self.min, min(pg['damping'], self.max)) + model = Reproj( trimmed_dataset['camera_params'][:, :NUM_CAMERA_PARAMS].clone(), trimmed_dataset['points_3d'].clone() ).to(DEVICE) -strategy = pp.optim.strategy.TrustRegion(up=2.0, down=0.5**4) +strategy = TrustRegion(up=2.0, down=0.5**4) solver = PCG(tol=1e-4, maxiter=250) # or CuDSS() -optimizer = LM(model, matrix_free_normal=True, strategy=strategy, solver=solver, reject=30) +optimizer = Schur(model, matrix_free_normal=True, strategy=strategy, solver=solver, reject=30) diff --git a/bae/optim/optimizer.py b/bae/optim/optimizer.py index 7a40a70..9c5a445 100644 --- a/bae/optim/optimizer.py +++ b/bae/optim/optimizer.py @@ -3,9 +3,12 @@ import torch from pypose.optim import LevenbergMarquardt as ppLM import pypose as pp + +from warp.optim import linear +from bae.sparse.warp_wrappers import format_vec_for_bsr, torchbsr2wp, wp2torchbsr from ..autograd.graph import backward, construct_sbt from ..autograd.function import TrackingTensor -from ..sparse.py_ops import diagonal_op_ +from ..sparse.py_ops import diagonal_op_, inv_op from ..sparse.spgemm import CuSparse from ..utils.linear_operator import NormalMatVec @@ -121,3 +124,142 @@ def update_parameter(self, params, step): param[:, 7:] += d.view(param.shape[0], -1)[:, 6:] else: param.add_(d.view(param.shape)) + +import warp as wp +from warp import sparse +class Schur(LM): + @torch.no_grad() + def step(self, input, target=None, weight=None): + for pg in self.param_groups: + self.reject_count = 0 + weight = self.weight if weight is None else weight + R = self.model(input, target) + + R = R[0] + J = jacobian(R, pg['params']) + J[0] = J[0] + J[1] = J[1] + + self.last = self.loss = self.loss if hasattr(self, 'loss') \ + else self.model.loss(input, target) + # torch.cuda.nvtx.range_push("JTJc") + J0wp = torchbsr2wp(J[0]) + J0twp = sparse.bsr_transposed(J0wp) + U = sparse.bsr_mm(J0twp, J0wp) + # torch.cuda.nvtx.range_pop() + # J0D = J[0].to_dense() + # UD = U.to_dense() + # torch.testing.assert_close(UD, J0D.mT @ J0D) + # del J0D + # del UD + # torch.cuda.nvtx.range_push("JTJp") + J1wp = torchbsr2wp(J[1]) + J1twp = sparse.bsr_transposed(J1wp) + V = sparse.bsr_mm(J1twp, J1wp) + # torch.cuda.nvtx.range_pop() + # J1D = J[1].to_dense() + # VD = V.to_dense() + # torch.testing.assert_close(VD, J1D.mT @ J1D) + # del J1D + # del VD + + # torch.cuda.nvtx.range_push("Clamp") + Upt = wp2torchbsr(U) + Vpt = wp2torchbsr(V) + diagonal_op_(Upt, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) + diagonal_op_(Vpt, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) + # torch.cuda.nvtx.range_pop() + + while self.last <= self.loss: + damping = pg['damping'] + R = R.reshape(-1) + + # torch.cuda.nvtx.range_push("Damp") + # damp = lambda x: x.pow(2) * damping + x + damp = partial(torch.mul, other=1+damping) + diagonal_op_(Upt, op=damp) + diagonal_op_(Vpt, op=damp) + # sparse.bsr_set_diag(U, sparse.bsr_get_diag(U) * (1+pg['damping'])) + # sparse.bsr_set_diag(V, sparse.bsr_get_diag(V) * (1+pg['damping'])) + # torch.cuda.nvtx.range_pop() + + # torch.cuda.nvtx.range_push("W") + W = J0twp @ J1wp + # torch.cuda.nvtx.range_pop() + # torch.cuda.nvtx.range_push("Ic") + Rwp = format_vec_for_bsr(R, J0twp.block_shape) + Ic = sparse.bsr_mv(J0twp, Rwp, alpha=-1.0) + Ip = sparse.bsr_mv(J1twp, Rwp, alpha=-1.0) + # torch.cuda.nvtx.range_pop() + # torch.cuda.nvtx.range_push("Inv") + V_i = torchbsr2wp(inv_op(Vpt)) + # torch.cuda.nvtx.range_pop() + # torch.cuda.nvtx.range_push("WVi") + WV_i = W @ V_i + # torch.cuda.nvtx.range_pop() + # torch.cuda.nvtx.range_push("rhs1") + rhs = sparse.bsr_mv(WV_i, Ip, y=Ic, alpha=-1.0, beta=1.0) + # torch.cuda.nvtx.range_pop() + # torch.cuda.nvtx.range_push("lhs1") + Wt = W.transpose() + lhs = sparse.bsr_axpy(U, WV_i @ Wt, alpha=1.0, beta=-1.0) # this matrix is NOT symetric + # torch.cuda.nvtx.range_pop() + D_c = wp.zeros_like(rhs) + # torch.cuda.nvtx.range_push("Solve C") + solver_tol = getattr(self.solver, "tol", None) + solver_maxiter = getattr(self.solver, "maxiter", 0) or 0 + results = linear.cg( + A=lhs, + b=rhs, + x=D_c, + tol=solver_tol, + maxiter=solver_maxiter, + M=linear.preconditioner(lhs), + ) + + # torch.cuda.nvtx.range_pop() + + # torch.cuda.nvtx.range_push("rhs2") + + rhs = sparse.bsr_mv(Wt, D_c, alpha=-1.0, beta=1.0, y=Ip) # rhs = Ip - Wt @ D_c + # torch.cuda.nvtx.range_pop() + # torch.cuda.nvtx.range_push("solve2") + lhs = V + D_p = wp.zeros_like(rhs) + results = linear.cg( + A=lhs, + b=rhs, + x=D_p, + tol=solver_tol, + maxiter=solver_maxiter, + M=linear.preconditioner(lhs), + ) + # torch.cuda.nvtx.range_pop() + # torch.cuda.nvtx.range_push("Update") + D_c = wp.to_torch(D_c).flatten() + D_p = wp.to_torch(D_p).flatten() + D = torch.cat([D_c, D_p]) + self.update_parameter(pg['params'], D) + # torch.cuda.nvtx.range_pop() + self.loss = self.model.loss(input, target) + print("Loss:", self.loss, "Last Loss:", self.last, "Reject Count:", self.reject_count, "Damping:", pg['damping']) + # torch.cuda.nvtx.range_push("Strategy") + # self.strategy.update(pg, last=self.last, loss=self.loss, J=J, D=D, R=R.view(-1, 1)) + # Pass Warp-format Jacobians as well so strategies can do bsrmv without + # hitting PyTorch's CUDA BSR matvec limitation for rectangular blocks. + self.strategy.update( + pg, + last=self.last, + loss=self.loss, + J=J, + Jwp=[J0wp, J1wp], + D=[D_c, D_p], + R=R.view(-1, 1), + ) + # torch.cuda.nvtx.range_pop() + if self.last < self.loss and self.reject_count < self.reject: # reject step + self.update_parameter(params = pg['params'], step = -D) + self.loss, self.reject_count = self.last, self.reject_count + 1 + else: + break + return self.loss diff --git a/bae/sparse/py_ops.py b/bae/sparse/py_ops.py index 464241e..6f7ba0b 100644 --- a/bae/sparse/py_ops.py +++ b/bae/sparse/py_ops.py @@ -172,7 +172,12 @@ def inv_op(input): bsr_values = input.values() # 1 + 2 dimensional inv_values = torch.linalg.inv(bsr_values) - return torch.sparse_bsc_tensor(crow_indices, col_indices, inv_values) + return torch.sparse_bsr_tensor( + crow_indices=crow_indices, + col_indices=col_indices, + values=inv_values, + size=input.shape, + dtype=input.dtype, ) def to_cooh(input): crow_indices = input.crow_indices() # b + 1 dimensional From acd1b3c6873eb5d872dc8d96d94498e37595cd95 Mon Sep 17 00:00:00 2001 From: Zitong Zhan Date: Tue, 13 Jan 2026 22:13:03 -0500 Subject: [PATCH 06/28] safely import cudss --- ba_example.py | 1 - bae/sparse/__init__.py | 11 +++++++++-- bae/utils/pysolvers.py | 14 +++++++++++++- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/ba_example.py b/ba_example.py index 550c25b..b578641 100644 --- a/ba_example.py +++ b/ba_example.py @@ -5,7 +5,6 @@ from ba_helpers import Reproj, least_square_error from datapipes.bal_loader import get_problem, read_bal_data from bae.sparse.py_ops import * -from bae.sparse.solve import * from bae.optim import LM from bae.utils.pysolvers import PCG, CuDSS diff --git a/bae/sparse/__init__.py b/bae/sparse/__init__.py index 3599d65..ac7d4e7 100644 --- a/bae/sparse/__init__.py +++ b/bae/sparse/__init__.py @@ -1,6 +1,13 @@ from .bsr import * from .bsr_cuda import * from .py_ops import * -from .solve import * +try: + from .solve import * +except ImportError: + # `bae.sparse.solve` depends on NVIDIA cuDSS. Some environments ship cuDSS built + # against a newer CUDA/cuBLAS (e.g. `libcublas.so.13`), which makes importing + # this package fail even if you don't use the direct solver. Keep the rest of + # the sparse ops usable and let callers opt into `bae.sparse.solve` explicitly. + pass from .conversion import * -from .warp_wrappers import * \ No newline at end of file +from .warp_wrappers import * diff --git a/bae/utils/pysolvers.py b/bae/utils/pysolvers.py index 772f6ab..e797a24 100644 --- a/bae/utils/pysolvers.py +++ b/bae/utils/pysolvers.py @@ -4,7 +4,19 @@ from pypose.optim.solver import CG from bae.sparse.py_ops import spdiags_ -from bae.sparse.solve import CuDirectSparseSolver as CuDSS +try: + from bae.sparse.solve import CuDirectSparseSolver as CuDSS +except Exception as e: + _cudss_import_error = e + + class CuDSS(torch.nn.Module): + def __init__(self, *args, **kwargs): + raise ImportError( + "CuDSS solver is unavailable because `bae.sparse.solve` failed to import. " + "This is commonly caused by a CUDA/cuDSS/cuBLAS mismatch (e.g. missing " + "`libcublas.so.13`). Use `PCG(...)` instead, or install a cuDSS build that " + "matches your CUDA toolkit." + ) from _cudss_import_error class PCG(CG): From 91c8ade1a8da0c4b59085b037cc0581a01196140 Mon Sep 17 00:00:00 2001 From: Zitong Zhan Date: Fri, 19 Dec 2025 19:39:08 -0500 Subject: [PATCH 07/28] Add future plans section to README Added a section for future plans including a new backend for distributed solver. --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 887b944..d87f085 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,9 @@ - **PyTorch Integration**: Seamlessly integrates with PyTorch's automatic differentiation framework - **Levenberg-Marquardt Optimizer**: Custom implementation of the LM algorithm for non-linear least squares problems +### Future Plan +- [ ] An new backend for [distributed solver](https://github.com/NVIDIA/AMGX) + ## Installation ### Prerequisites From 19774c3a49408570c5f0780026be9ca61befe407 Mon Sep 17 00:00:00 2001 From: "Zitong Zhan (PVE)" Date: Sat, 13 Dec 2025 04:32:24 +0000 Subject: [PATCH 08/28] add normal matvec and memory profiler --- ba_example.py | 36 +++++- bae/optim/optimizer.py | 26 ++-- bae/sparse/py_ops.py | 25 ++-- bae/utils/__init__.py | 1 + bae/utils/linear_operator.py | 163 ++++++++++++++++++++++++ bae/utils/pysolvers.py | 230 +--------------------------------- tests/test_normal_operator.py | 87 +++++++++++++ 7 files changed, 317 insertions(+), 251 deletions(-) create mode 100644 bae/utils/linear_operator.py create mode 100644 tests/test_normal_operator.py diff --git a/ba_example.py b/ba_example.py index b578641..1bcf2bd 100644 --- a/ba_example.py +++ b/ba_example.py @@ -1,4 +1,6 @@ from time import perf_counter +from pathlib import Path +from datetime import datetime import torch import pypose as pp @@ -8,13 +10,13 @@ from bae.optim import LM from bae.utils.pysolvers import PCG, CuDSS -# TARGET_DATASET = "ladybug" -# TARGET_PROBLEM = "problem-1723-156502-pre" +TARGET_DATASET = "ladybug" +TARGET_PROBLEM = "problem-1723-156502-pre" # TARGET_PROBLEM = "problem-49-7776-pre" # TARGET_PROBLEM = "problem-1695-155710-pre" # TARGET_PROBLEM = "problem-969-105826-pre" -TARGET_DATASET = "trafalgar" -TARGET_PROBLEM = "problem-257-65132-pre" +# TARGET_DATASET = "trafalgar" +# TARGET_PROBLEM = "problem-257-65132-pre" # TARGET_DATASET = "dubrovnik" # TARGET_PROBLEM = "problem-356-226730-pre" @@ -27,6 +29,21 @@ file_name = f'{TARGET_DATASET}.{TARGET_PROBLEM}' dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET, use_quat=USE_QUATERNIONS) +memory_snapshot_path = None + +if DEVICE.startswith("cuda") and torch.cuda.is_available(): + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + snapshot_dir = Path("memory_traces") + snapshot_dir.mkdir(exist_ok=True) + memory_snapshot_path = snapshot_dir / f"{file_name}_cuda_memory_{timestamp}.pickle" + # Record allocator events so we can inspect GPU memory usage after the run. + torch.cuda.memory._record_memory_history( + enabled="all", + context="all", + stacks="python", + device=torch.device(DEVICE), + clear_history=True, + ) if OPTIMIZE_INTRINSICS: NUM_CAMERA_PARAMS = 10 if USE_QUATERNIONS else 9 @@ -50,7 +67,9 @@ ).to(DEVICE) strategy = pp.optim.strategy.TrustRegion(up=2.0, down=0.5**4) solver = PCG(tol=1e-4, maxiter=250) # or CuDSS() -optimizer = LM(model, strategy=strategy, solver=solver, reject=30) +optimizer = LM(model, matrix_free_normal=True, strategy=strategy, solver=solver, reject=30) + + print('Loss:', least_square_error( model.pose, @@ -67,6 +86,13 @@ loss = optimizer.step(input) print('Iteration', idx, 'loss', loss.item(), 'time', perf_counter() - start) +if memory_snapshot_path: + torch.cuda.synchronize() + torch.cuda.memory._dump_snapshot(str(memory_snapshot_path)) + print(f"CUDA memory snapshot saved to {memory_snapshot_path}") + +# exit() + torch.cuda.synchronize() end = perf_counter() print('Time', end - start) diff --git a/bae/optim/optimizer.py b/bae/optim/optimizer.py index 6d89626..70c610f 100644 --- a/bae/optim/optimizer.py +++ b/bae/optim/optimizer.py @@ -7,12 +7,14 @@ from ..autograd.function import TrackingTensor from ..sparse.py_ops import diagonal_op_ from ..sparse.spgemm import CuSparse +from ..utils.linear_operator import NormalMatVec class LM(ppLM): - def __init__(self, *args, **kwargs): + def __init__(self, *args, matrix_free_normal: bool = False, **kwargs): + self.matrix_free_normal = matrix_free_normal super(LM, self).__init__(*args, **kwargs) self.mm = CuSparse() @@ -22,24 +24,34 @@ def step(self, input, target=None, weight=None): weight = self.weight if weight is None else weight R = list(self.model(input)) R = R[0] - J = jacobian(R, pg['params']) + J_list = jacobian(R, pg['params']) if isinstance(R, TrackingTensor): R = R.tensor() - J = torch.cat([j.to_sparse_coo() for j in J], dim=-1) + J = torch.cat([j.to_sparse_coo() for j in J_list], dim=-1) self.last = self.loss = self.loss if hasattr(self, 'loss') else self.model.loss(input, target) J_T = J.mT self.reject_count = 0 J_T = J_T.to_sparse_csr() J = J.to_sparse_csr() - A = self.mm(J_T, J) + rhs = -J_T @ R.view(-1, 1) - diagonal_op_(A, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) + if self.matrix_free_normal: + diag = NormalMatVec._compute_diag(J).clamp(min=pg['min'], max=pg['max']) + A = NormalMatVec(J, damping=0.0, diag=diag) + diag_scale = 1.0 + else: + A = self.mm(J_T, J) + diagonal_op_(A, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) while self.last <= self.loss: - diagonal_op_(A, op=partial(torch.mul, other=1+pg['damping'])) + if self.matrix_free_normal: + diag_scale *= 1.0 + pg['damping'] + A.set_damping(diag_scale - 1.0) + else: + diagonal_op_(A, op=partial(torch.mul, other=1+pg['damping'])) try: - D = self.solver(A, -J_T @ R.view(-1, 1)) + D = self.solver(A, rhs) D = D[:, None] except Exception as e: print(e, "\nLinear solver failed. Breaking optimization step...") diff --git a/bae/sparse/py_ops.py b/bae/sparse/py_ops.py index 19164f3..464241e 100644 --- a/bae/sparse/py_ops.py +++ b/bae/sparse/py_ops.py @@ -232,16 +232,15 @@ def bsr2bsc(J): sparse_lib.impl('diagonal', diagonal_op_, 'SparseCsrCPU') sparse_lib.impl('diagonal', diagonal_op_, 'SparseCsrCUDA') -if True: - crow_indices = torch.tensor([0, 2, 4]) - col_indices = torch.tensor([0, 1, 0, 1]) - values = torch.tensor([[[0, 1, 2], [6, 7, 8]], - [[3, 4, 5], [9, 10, 11]], - [[12, 13, 14], [18, 19, 20]], - [[15, 16, 17], [21, 22, 23]]]) - bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, dtype=torch.float64) - bsr = bsr.to('cuda') - csr = bsr.to_sparse_coo().to_sparse_csr() - # print(csr) - output = diagonal_op_triton_(csr) - # print(output) \ No newline at end of file +if __name__ == "__main__": + if torch.cuda.is_available(): + crow_indices = torch.tensor([0, 2, 4]) + col_indices = torch.tensor([0, 1, 0, 1]) + values = torch.tensor([[[0, 1, 2], [6, 7, 8]], + [[3, 4, 5], [9, 10, 11]], + [[12, 13, 14], [18, 19, 20]], + [[15, 16, 17], [21, 22, 23]]]) + bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, dtype=torch.float64) + bsr = bsr.to('cuda') + csr = bsr.to_sparse_coo().to_sparse_csr() + output = diagonal_op_triton_(csr) diff --git a/bae/utils/__init__.py b/bae/utils/__init__.py index e69de29..48ea6d3 100644 --- a/bae/utils/__init__.py +++ b/bae/utils/__init__.py @@ -0,0 +1 @@ +from .linear_operator import NormalMatVec diff --git a/bae/utils/linear_operator.py b/bae/utils/linear_operator.py new file mode 100644 index 0000000..070badf --- /dev/null +++ b/bae/utils/linear_operator.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from typing import Optional, Union, Any, Tuple + +import torch +from torch import Tensor + + +class NormalMatVec: + r"""Matrix-free normal-equation linear operator. + + Given a Jacobian J (dense, CSR/COO, or BSR), this operator represents: + + A x = J^T (J x) + damping * diag(J^T J) * x + + where diag(J^T J) is computed as the column-wise sum of squares of J. + """ + + def __init__( + self, + J: Tensor, + damping: Union[float, Tensor] = 0.0, + diag: Optional[Tensor] = None, + ): + if not torch.is_tensor(J): + raise TypeError("J must be a torch.Tensor") + if J.ndim != 2: + raise ValueError("J must be 2-D") + + self.J: Tensor = J + self._Jt: Optional[Tensor] = None + + self.device = J.device + self.dtype = J.dtype + self.layout = J.layout if J.layout != torch.sparse_coo else torch.sparse_csr + + self.shape: Tuple[int, int] = (J.shape[1], J.shape[1]) + self.ndim: int = 2 + + self._diag: Tensor = diag if diag is not None else self._compute_diag(J) + if self._diag.ndim != 1 or self._diag.numel() != J.shape[1]: + raise ValueError("diag must be 1-D with length equal to J.shape[1]") + + self.set_damping(damping) + + def set_damping(self, damping: Union[float, Tensor]) -> None: + if isinstance(damping, Tensor): + if damping.numel() != 1: + raise ValueError("damping tensor must be scalar") + self.damping = damping.to(device=self.device, dtype=self.dtype) + else: + self.damping = float(damping) + + def diagonal(self) -> Tensor: + damp = self._damping_value() + if damp is None: + return self._diag + return self._diag * (1.0 + damp) + + def matvec(self, x: Tensor) -> Tensor: + if x.ndim == 1: + x2d = x.unsqueeze(-1) + elif x.ndim == 2: + x2d = x + else: + raise ValueError("x must be 1-D or 2-D") + + if x2d.device != self.device or x2d.dtype != self.dtype: + x2d = x2d.to(device=self.device, dtype=self.dtype) + + y = self.J @ x2d + Jt = self._get_Jt() + z = Jt @ y + + damp = self._damping_value() + if damp is not None: + z = z + damp * self._diag.unsqueeze(-1) * x2d + + return z.squeeze(-1) if x.ndim == 1 else z + + def __call__(self, x: Tensor) -> Tensor: + return self.matvec(x) + + def __matmul__(self, x: Tensor) -> Tensor: + return self.matvec(x) + + @classmethod + def __torch_function__( # type: ignore[override] + cls, func: Any, types: Any, args: Tuple[Any, ...] = (), kwargs: Optional[dict] = None + ): + if kwargs is None: + kwargs = {} + if func is torch.matmul and len(args) >= 2: + A, B = args[0], args[1] + if isinstance(A, cls): + out = kwargs.get("out", None) + result = A.matvec(B) + if out is not None: + out_tensor = out[0] if isinstance(out, tuple) else out + out_tensor.copy_(result) + return out_tensor + return result + return NotImplemented + + def _get_Jt(self) -> Tensor: + if self._Jt is None: + Jt = self.J.mT + if Jt.layout == torch.sparse_csc: + Jt = Jt.to_sparse_csr() + elif Jt.layout == torch.sparse_bsc: + bs = Jt.values().shape[-2:] + Jt = Jt.to_sparse_bsr(blocksize=bs) + self._Jt = Jt + return self._Jt + + def _damping_value(self) -> Optional[Tensor]: + if isinstance(self.damping, Tensor): + if self.damping.item() == 0.0: + return None + return self.damping + if self.damping == 0.0: + return None + return torch.tensor(self.damping, device=self.device, dtype=self.dtype) + + @staticmethod + def _compute_diag(J: Tensor) -> Tensor: + if J.layout == torch.strided: + return J.square().sum(dim=0) + + if J.layout == torch.sparse_bsr: + values = J.values() + dm, dn = values.shape[-2], values.shape[-1] + col_blocks = J.col_indices() + contrib = values.square().sum(dim=-2) # (nnz_blocks, dn) + offsets = torch.arange(dn, device=contrib.device, dtype=col_blocks.dtype) + cols = (col_blocks[:, None] * dn + offsets[None, :]).reshape(-1).to(torch.int64) + contrib_flat = contrib.reshape(-1) + diag = torch.zeros(J.shape[1], device=contrib.device, dtype=contrib.dtype) + diag.scatter_add_(0, cols, contrib_flat) + return diag + + if J.layout == torch.sparse_csr: + values = J.values() + col = J.col_indices().to(torch.int64) + v2 = values.square() + if v2.ndim > 1: + v2 = v2.reshape(v2.shape[0], -1).sum(dim=-1) + diag = torch.zeros(J.shape[1], device=values.device, dtype=v2.dtype) + diag.scatter_add_(0, col, v2) + return diag + + if J.layout == torch.sparse_coo: + Jc = J.coalesce() + col = Jc.indices()[1].to(torch.int64) + v2 = Jc.values().square() + if v2.ndim > 1: + v2 = v2.reshape(v2.shape[0], -1).sum(dim=-1) + diag = torch.zeros(J.shape[1], device=Jc.device, dtype=v2.dtype) + diag.scatter_add_(0, col, v2) + return diag + + raise NotImplementedError(f"Unsupported J layout: {J.layout}") + diff --git a/bae/utils/pysolvers.py b/bae/utils/pysolvers.py index e797a24..d4fb706 100644 --- a/bae/utils/pysolvers.py +++ b/bae/utils/pysolvers.py @@ -25,14 +25,15 @@ def __init__(self, maxiter=None, tol=1e-5): def forward(self, A, b, x=None, M=None) -> torch.Tensor: if b.dim() == 1: b = b[..., None] - l_diag = A.diagonal() + l_diag = A.diagonal().clone() l_diag[l_diag.abs() < 1e-6] = 1e-6 M = spdiags_((1 / l_diag), None, shape=A.shape, layout=None) - if A.layout == torch.sparse_csr: + layout = getattr(A, "layout", torch.strided) + if layout == torch.sparse_csr: # M = M.to_sparse_csr() pass # A = M @ A - elif A.layout == torch.sparse_bsr: + elif layout == torch.sparse_bsr and isinstance(A, torch.Tensor): M = M.to_sparse_bsr(blocksize=A.values().shape[-2:]).to(A.device) # A = M @ A.to_sparse_bsc(blocksize=A.values().shape[-2:]) # b = M @ b @@ -62,226 +63,3 @@ def forward(self, A, b): # print(f"Linear Solver Error: {a_err}, relative error: {r_err}") return torch.from_numpy(x).to(A.device) - -# cuda graph version of the solver -class CG_(torch.nn.Module): - r'''The batched linear solver with conjugate gradient method. - - .. math:: - \mathbf{A}_i \bm{x}_i = \mathbf{b}_i, - - where :math:`\mathbf{A}_i \in \mathbb{C}^{M \times N}` and :math:`\bm{b}_i \in - \mathbb{C}^{M \times 1}` are the :math:`i`-th item of batched linear equations. - - This function is a 1:1 replica of `scipy.sparse.linalg.cg `_. - The solution is consistent with the scipy version up to numerical precision. - Variable names are kept the same as the scipy version for easy reference. - We recommend using only non-batched or batch size 1 input for this solver, as - the batched version was not appeared in the original scipy version. When handling - sparse matrices, the batched computation may introduce additional overhead. - - Examples: - >>> # dense example - >>> import pypose.optim.solver as ppos - >>> A = torch.tensor([[0.1802967, 0.3151198, 0.4548111, 0.3860016, 0.2870615], - [0.3151198, 1.4575327, 1.5533425, 1.0540756, 1.0795838], - [0.4548111, 1.5533425, 2.3674474, 1.1222278, 1.2365348], - [0.3860016, 1.0540756, 1.1222278, 1.3748058, 1.2223261], - [0.2870615, 1.0795838, 1.2365348, 1.2223261, 1.2577004]]) - >>> b = torch.tensor([[ 2.64306851], - [-0.03593633], - [ 0.73612658], - [ 0.51501254], - [-0.26689271]]) - >>> solver = ppos.CG() - >>> x = solver(A, b) - tensor([[246.4098], - [ 22.6997], - [-56.9239], - [-161.7914], - [137.2683]]) - - >>> # sparse csr example - >>> import pypose.optim.solver as ppos - >>> crow_indices = torch.tensor([0, 2, 4]) - >>> col_indices = torch.tensor([0, 1, 0, 1]) - >>> values = torch.tensor([1, 2, 3, 4], dtype=torch.float) - >>> A = torch.sparse_csr_tensor(crow_indices, col_indices, values) - >>> A.to_dense() # visualize - tensor([[1., 2.], - [3., 4.]]) - >>> b = torch.tensor([[1.], [2.]]) - >>> solver = ppos.CG() - >>> x = solver(A, b) - tensor([-4.4052e-05, 5.0003e-01]) - - ''' - def __init__(self, maxiter=None, tol=1e-5): - super().__init__() - self.maxiter, self.tol = maxiter, tol - self.graph_first_iter = None - self.graph_subsequent_iter = None - self.static_A_shape, self.static_b_shape, self.static_M_is_none, self.static_device = \ - None, None, None, None - # Tensors for graph capture/replay - self.static_A, self.static_b, self.static_M = None, None, None - self.static_x, self.static_r, self.static_p, self.static_q, self.static_z = \ - None, None, None, None, None - self.static_rho_prev, self.static_rho_cur = None, None - - def forward(self, A: torch.Tensor, b: Tensor, x: Optional[Tensor]=None, - M: Optional[torch.Tensor]=None) -> Tensor: - ''' - Args: - A (Tensor): the input tensor. It is assumed to be a symmetric - positive-definite matrix. Layout is allowed to be COO, CSR, BSR, or dense. - b (Tensor): the tensor on the right hand side. Layout could be sparse or dense - but is only allowed to be a type that is compatible with the layout of A. - In other words, `A @ b` operation must be supported by the layout of A. - x (Tensor, optional): the initial guess for the solution. Default: ``None``. - M (Tensor, optional): the preconditioner for A. Layout is allowed to be COO, - CSR, BSR, or dense. Default: ``None``. - - Return: - Tensor: the solved tensor. Layout is the same as the layout of b. - ''' - if A.ndim == b.ndim + 1: - b = b.unsqueeze(-1) - else: - assert A.ndim == b.ndim, \ - 'The number of dimensions of A and b must be the same or one more than b' - - if x is None: - x = torch.zeros_like(b) - - bnrm2 = torch.linalg.norm(b, dim=0) - if (bnrm2 == 0).all(): - return b - atol = self.tol * bnrm2 - n = b.shape[-2] - - if self.maxiter is None: - maxiter = n * 10 - else: - maxiter = self.maxiter - - # Determine if CUDA graph can be used and if re-capture is needed - use_cuda_graph = A.is_cuda - - if use_cuda_graph: - re_capture_graph = (self.graph_first_iter is None or \ - self.static_A_shape != A.shape or \ - self.static_b_shape != b.shape or \ - self.static_M_is_none != (M is None) or \ - self.static_device != A.device) - - if re_capture_graph: - # Allocate static tensors and capture new graphs - self.static_A = A.clone() - self.static_b = b.clone() - self.static_x = x.clone() # Initial x - self.static_r = b - A @ x # Initial r - self.static_p = torch.zeros_like(b) # Will be updated - self.static_q = torch.empty_like(b) - self.static_z = torch.empty_like(b) - - # Initialize rho_prev and rho_cur with shape [1, 1] - self.static_rho_prev = torch.zeros(1, 1, device=A.device) - self.static_rho_cur = torch.zeros(1, 1, device=A.device) - - self.static_M_is_none = (M is None) - self.static_device = A.device - self.static_A_shape = A.shape - self.static_b_shape = b.shape - - if M is not None: - self.static_M = M.clone() - else: - self.static_M = None - - # Capture first iteration graph - self.graph_first_iter = torch.cuda.CUDAGraph() - torch.cuda.synchronize() - with torch.cuda.graph(self.graph_first_iter): - # Operations for first iteration - if not self.static_M_is_none: - torch.matmul(self.static_M, self.static_r, out=self.static_z) - else: - self.static_z.copy_(self.static_r) # z = r.clone() - self.static_rho_cur.copy_(torch.matmul(self.static_r.mT, self.static_z)) - self.static_p.copy_(self.static_z) # p = z.clone() - torch.matmul(self.static_A, self.static_p, out=self.static_q) - alpha = self.static_rho_cur / torch.matmul(self.static_p.mT, self.static_q) - self.static_x.add_(alpha * self.static_p) - self.static_r.sub_(alpha * self.static_q) - self.static_rho_prev.copy_(self.static_rho_cur) - - # Capture subsequent iteration graph - self.graph_subsequent_iter = torch.cuda.CUDAGraph() - torch.cuda.synchronize() - with torch.cuda.graph(self.graph_subsequent_iter): - # Operations for subsequent iterations - if not self.static_M_is_none: - torch.matmul(self.static_M, self.static_r, out=self.static_z) - else: - self.static_z.copy_(self.static_r) # z = r.clone() - self.static_rho_cur.copy_(torch.matmul(self.static_r.mT, self.static_z)) - beta = self.static_rho_cur / self.static_rho_prev - self.static_p.mul_(beta).add_(self.static_z) - torch.matmul(self.static_A, self.static_p, out=self.static_q) - alpha = self.static_rho_cur / torch.matmul(self.static_p.mT, self.static_q) - self.static_x.add_(alpha * self.static_p) - self.static_r.sub_(alpha * self.static_q) - self.static_rho_prev.copy_(self.static_rho_cur) - - # Now run the loop using the (newly captured or existing) graphs - self.static_A.copy_(A) - self.static_b.copy_(b) - self.static_x.copy_(x) - self.static_r.copy_(b - A @ x) # Initial r - if M is not None: - self.static_M.copy_(M) - - # First iteration - self.graph_first_iter.replay() - if (torch.linalg.norm(self.static_r, dim=0) < atol).all(): - return self.static_x - - # Subsequent iterations - for iteration in range(1, maxiter): - self.graph_subsequent_iter.replay() - if (torch.linalg.norm(self.static_r, dim=0) < atol).all(): - return self.static_x - return self.static_x - - else: # A is not on CUDA, or other conditions not met for graph, run original Python loop - r = b - A @ x if x.any() else b.clone() - rho_prev, p = None, None - - q = torch.empty_like(b) - if M is not None: - z = torch.empty_like(b) - else: - z = r.clone() - - for iteration in range(maxiter): - if (torch.linalg.norm(r, dim=0) < atol).all(): - return x - - if M is not None: - torch.matmul(M, r, out=z) - rho_cur = torch.matmul(r.mT, z) - if iteration > 0: - beta = rho_cur / rho_prev - p.mul_(beta).add_(z) - else: # First spin - p = z.clone() - - torch.matmul(A, p, out=q) - alpha = rho_cur / torch.matmul(p.mT, q) - x += alpha * p - r -= alpha * q - rho_prev = rho_cur - - return x diff --git a/tests/test_normal_operator.py b/tests/test_normal_operator.py new file mode 100644 index 0000000..b101514 --- /dev/null +++ b/tests/test_normal_operator.py @@ -0,0 +1,87 @@ +import pytest +import torch + +from bae.utils.linear_operator import NormalMatVec +from bae.utils.pysolvers import PCG + + +def test_normal_matvec_dense_matches_explicit(): + torch.manual_seed(0) + m, n = 8, 5 + J = torch.randn(m, n, dtype=torch.float64) + x = torch.randn(n, dtype=torch.float64) + + op = NormalMatVec(J) + y_op = op.matvec(x) + y_ex = J.mT @ (J @ x) + + torch.testing.assert_close(y_op, y_ex, rtol=1e-10, atol=1e-10) + + +def test_normal_matvec_dense_damping_matches_explicit(): + torch.manual_seed(1) + m, n = 7, 4 + J = torch.randn(m, n, dtype=torch.float64) + x = torch.randn(n, dtype=torch.float64) + damping = 0.3 + + diag = J.square().sum(dim=0) + op = NormalMatVec(J, damping=damping) + y_op = op @ x + y_ex = J.mT @ (J @ x) + damping * diag * x + + torch.testing.assert_close(y_op, y_ex, rtol=1e-10, atol=1e-10) + + +def test_normal_matvec_sparse_csr_matches_explicit_and_cached_diag(): + torch.manual_seed(2) + m, n = 10, 6 + dense = torch.randn(m, n, dtype=torch.float64) + mask = (torch.rand_like(dense) < 0.5) + dense = dense * mask + J = dense.to_sparse_csr() + x = torch.randn(n, dtype=torch.float64) + + diag = dense.square().sum(dim=0) + op = NormalMatVec(J, diag=diag) + y_op = op.matvec(x) + y_ex = dense.mT @ (dense @ x) + + torch.testing.assert_close(y_op, y_ex, rtol=1e-10, atol=1e-10) + + +def test_normal_matvec_sparse_bsr_matches_explicit(): + torch.manual_seed(3) + m, n = 6, 4 + dense = torch.randn(m, n, dtype=torch.float64).contiguous() + x = torch.randn(n, dtype=torch.float64) + + try: + J_bsr = dense.to_sparse_bsr(blocksize=(2, 2)) + except Exception: + pytest.skip("BSR conversion not supported on this device/build.") + + op = NormalMatVec(J_bsr) + y_op = op @ x + y_ex = dense.mT @ (dense @ x) + + torch.testing.assert_close(y_op, y_ex, rtol=1e-10, atol=1e-10) + + +def test_pcg_smoke_with_normal_operator(): + torch.manual_seed(4) + m, n = 12, 5 + J = torch.randn(m, n, dtype=torch.float64) + damping = 1e-3 + op = NormalMatVec(J, damping=damping) + + b = torch.randn(n, dtype=torch.float64) + solver = PCG(tol=1e-10, maxiter=200) + x_pcg = solver(op, b) + + diag = J.square().sum(dim=0) + A_dense = J.mT @ J + damping * torch.diag(diag) + x_ex = torch.linalg.solve(A_dense, b) + + torch.testing.assert_close(x_pcg, x_ex, rtol=1e-6, atol=1e-6) + From 4ca9c86408f68651d67dc3cca14961443e2ddeed Mon Sep 17 00:00:00 2001 From: "Zitong Zhan (PVE)" Date: Tue, 23 Dec 2025 18:48:20 +0000 Subject: [PATCH 09/28] print peak cuda allocation --- ba_example.py | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/ba_example.py b/ba_example.py index 1bcf2bd..e427cd3 100644 --- a/ba_example.py +++ b/ba_example.py @@ -30,6 +30,20 @@ file_name = f'{TARGET_DATASET}.{TARGET_PROBLEM}' dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET, use_quat=USE_QUATERNIONS) memory_snapshot_path = None +cuda_device = torch.device(DEVICE) if DEVICE.startswith("cuda") else None + + +def _format_bytes(num_bytes: int) -> str: + units = ["B", "KiB", "MiB", "GiB", "TiB"] + size = float(num_bytes) + unit = units[0] + for unit in units: + if size < 1024.0 or unit == units[-1]: + break + size /= 1024.0 + if unit == "B": + return f"{int(size)} {unit}" + return f"{size:.2f} {unit}" if DEVICE.startswith("cuda") and torch.cuda.is_available(): timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") @@ -41,7 +55,7 @@ enabled="all", context="all", stacks="python", - device=torch.device(DEVICE), + device=cuda_device, clear_history=True, ) @@ -81,22 +95,36 @@ print("Initial loss", optimizer.model.loss(input, None).item()) +if cuda_device is not None and torch.cuda.is_available(): + torch.cuda.synchronize(cuda_device) + torch.cuda.reset_peak_memory_stats(cuda_device) + start = perf_counter() for idx in range(20): loss = optimizer.step(input) print('Iteration', idx, 'loss', loss.item(), 'time', perf_counter() - start) if memory_snapshot_path: - torch.cuda.synchronize() + torch.cuda.synchronize(cuda_device) torch.cuda.memory._dump_snapshot(str(memory_snapshot_path)) print(f"CUDA memory snapshot saved to {memory_snapshot_path}") # exit() -torch.cuda.synchronize() +if cuda_device is not None and torch.cuda.is_available(): + torch.cuda.synchronize(cuda_device) end = perf_counter() print('Time', end - start) +if cuda_device is not None and torch.cuda.is_available(): + peak_allocated = torch.cuda.max_memory_allocated(cuda_device) + try: + peak_reserved = torch.cuda.max_memory_reserved(cuda_device) + except AttributeError: # older PyTorch + peak_reserved = torch.cuda.max_memory_cached(cuda_device) + print(f"Peak CUDA memory allocated: {_format_bytes(peak_allocated)}") + print(f"Peak CUDA memory reserved: {_format_bytes(peak_reserved)}") + print('Ending loss:', least_square_error( model.pose, model.points_3d, From b71f1a3dbfde6682be8a43b78de514c8b0834ccd Mon Sep 17 00:00:00 2001 From: "Zitong Zhan (PVE)" Date: Sun, 28 Dec 2025 03:12:06 +0000 Subject: [PATCH 10/28] add warp memory pool report --- ba_example.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/ba_example.py b/ba_example.py index e427cd3..b74937d 100644 --- a/ba_example.py +++ b/ba_example.py @@ -26,11 +26,15 @@ OPTIMIZE_INTRINSICS = True USE_QUATERNIONS = True +REPORT_WARP_MEMPOOL = True file_name = f'{TARGET_DATASET}.{TARGET_PROBLEM}' dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET, use_quat=USE_QUATERNIONS) memory_snapshot_path = None cuda_device = torch.device(DEVICE) if DEVICE.startswith("cuda") else None +warp_device = None +warp_mempool_start_current = None +warp_mempool_start_high = None def _format_bytes(num_bytes: int) -> str: @@ -59,6 +63,17 @@ def _format_bytes(num_bytes: int) -> str: clear_history=True, ) +if REPORT_WARP_MEMPOOL and DEVICE.startswith("cuda"): + try: + if wp.is_cuda_available(): + warp_device = wp.get_device("cuda:0" if DEVICE == "cuda" else DEVICE) + if not wp.is_mempool_enabled(warp_device): + wp.set_mempool_enabled(warp_device, True) + warp_mempool_start_current = wp.get_mempool_used_mem_current(warp_device) + warp_mempool_start_high = wp.get_mempool_used_mem_high(warp_device) + except Exception as e: + print(f"Warning: failed to query Warp mempool stats: {e}") + if OPTIMIZE_INTRINSICS: NUM_CAMERA_PARAMS = 10 if USE_QUATERNIONS else 9 else: @@ -125,6 +140,15 @@ def _format_bytes(num_bytes: int) -> str: print(f"Peak CUDA memory allocated: {_format_bytes(peak_allocated)}") print(f"Peak CUDA memory reserved: {_format_bytes(peak_reserved)}") +if warp_device is not None and warp_mempool_start_current is not None and warp_mempool_start_high is not None: + try: + warp_current = wp.get_mempool_used_mem_current(warp_device) + warp_high = wp.get_mempool_used_mem_high(warp_device) + print(f"Warp CUDA mempool current: {_format_bytes(warp_current)} (Δ {_format_bytes(warp_current - warp_mempool_start_current)})") + print(f"Warp CUDA mempool high-water: {_format_bytes(warp_high)} (Δ {_format_bytes(warp_high - warp_mempool_start_high)})") + except Exception as e: + print(f"Warning: failed to query Warp mempool stats: {e}") + print('Ending loss:', least_square_error( model.pose, model.points_3d, From d67886752f0840ddb9eb7028e209448378d181b7 Mon Sep 17 00:00:00 2001 From: "Zitong Zhan (PVE)" Date: Sun, 28 Dec 2025 04:26:48 +0000 Subject: [PATCH 11/28] use `A._get_Jt` when matrix_free_normal --- bae/optim/optimizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bae/optim/optimizer.py b/bae/optim/optimizer.py index 70c610f..247b864 100644 --- a/bae/optim/optimizer.py +++ b/bae/optim/optimizer.py @@ -30,17 +30,17 @@ def step(self, input, target=None, weight=None): J = torch.cat([j.to_sparse_coo() for j in J_list], dim=-1) self.last = self.loss = self.loss if hasattr(self, 'loss') else self.model.loss(input, target) - J_T = J.mT self.reject_count = 0 - J_T = J_T.to_sparse_csr() J = J.to_sparse_csr() - rhs = -J_T @ R.view(-1, 1) if self.matrix_free_normal: diag = NormalMatVec._compute_diag(J).clamp(min=pg['min'], max=pg['max']) A = NormalMatVec(J, damping=0.0, diag=diag) + rhs = -(A._get_Jt() @ R.view(-1, 1)) diag_scale = 1.0 else: + J_T = J.mT.to_sparse_csr() + rhs = -J_T @ R.view(-1, 1) A = self.mm(J_T, J) diagonal_op_(A, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) From d127b8828a8b2886df7858d0730bf8d7accf32fd Mon Sep 17 00:00:00 2001 From: "Zitong Zhan (PVE)" Date: Sun, 28 Dec 2025 03:11:03 +0000 Subject: [PATCH 12/28] add back schur by warp's matmul --- .gitignore | 1 + ba_example.py | 63 +++++++++++++++++- bae/optim/optimizer.py | 144 ++++++++++++++++++++++++++++++++++++++++- bae/sparse/py_ops.py | 7 +- 4 files changed, 211 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index a5f0941..4dc0a28 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ tmp_debug/* task.md *.pickle *.png +.warp_cache/* \ No newline at end of file diff --git a/ba_example.py b/ba_example.py index b74937d..2105da4 100644 --- a/ba_example.py +++ b/ba_example.py @@ -3,12 +3,16 @@ from datetime import datetime import torch import pypose as pp +import warp as wp +from warp import sparse as wpsparse from ba_helpers import Reproj, least_square_error +from bae.optim.optimizer import Schur from datapipes.bal_loader import get_problem, read_bal_data from bae.sparse.py_ops import * from bae.optim import LM from bae.utils.pysolvers import PCG, CuDSS +from bae.sparse.warp_wrappers import format_vec_for_bsr TARGET_DATASET = "ladybug" TARGET_PROBLEM = "problem-1723-156502-pre" @@ -90,13 +94,68 @@ def _format_bytes(num_bytes: int) -> str: "point_indices": trimmed_dataset['point_index_of_observations'] } +class TrustRegion(pp.optim.strategy.TrustRegion): + def update(self, pg, last, loss, J, D, R, *args, **kwargs): + # PyTorch CUDA BSR matvec currently assumes square blocks; allow passing Warp + # BSR matrices via `Jwp` to use Warp's bsrmv for rectangular blocks. + Jwp = kwargs.get("Jwp") + if Jwp is not None: + J = Jwp + JD = None + for i in range(len(D)): + if JD is None: + if Jwp is not None: + Dwp = format_vec_for_bsr(D[i].flatten().contiguous(), J[i].block_shape) + JD = wp.to_torch(wpsparse.bsr_mv(J[i], Dwp)).flatten() + else: + JD = J[i] @ D[i].flatten() + else: + if Jwp is not None: + Dwp = format_vec_for_bsr(D[i].flatten().contiguous(), J[i].block_shape) + JD += wp.to_torch(wpsparse.bsr_mv(J[i], Dwp)).flatten() + else: + JD += J[i] @ D[i].flatten() + JD = JD[..., None] + quality = (last - loss) / -((JD).mT @ (2 * R.view_as(JD) + JD)).squeeze() + pg['radius'] = 1. / pg['damping'] + if quality > pg['high']: + pg['radius'] = pg['up'] * pg['radius'] + pg['down'] = self.down + elif quality > pg['low']: + pg['radius'] = pg['radius'] + pg['down'] = self.down + else: + pg['radius'] = pg['radius'] * pg['down'] + pg['down'] = pg['down'] * pg['factor'] + pg['down'] = max(self.min, min(pg['down'], self.max)) + pg['radius'] = max(self.min, min(pg['radius'], self.max)) + pg['damping'] = 1. / pg['radius'] +class Adaptive(pp.optim.strategy.Adaptive): + def update(self, pg, last, loss, J, D, R, *args, **kwargs): + J = [i.to_sparse_coo() for i in J] + JD = None + for i in range(len(D)): + if JD is None: + JD = J[i] @ D[i] + else: + JD += J[i] @ D[i] + JD = JD[..., None] + quality = (last - loss) / -((JD).mT @ (2 * R.view_as(JD) + JD)).squeeze() + if quality > pg['high']: + pg['damping'] = pg['damping'] * pg['down'] + elif quality > pg['low']: + pg['damping'] = pg['damping'] + else: + pg['damping'] = pg['damping'] * pg['up'] + pg['damping'] = max(self.min, min(pg['damping'], self.max)) + model = Reproj( trimmed_dataset['camera_params'][:, :NUM_CAMERA_PARAMS].clone(), trimmed_dataset['points_3d'].clone() ).to(DEVICE) -strategy = pp.optim.strategy.TrustRegion(up=2.0, down=0.5**4) +strategy = TrustRegion(up=2.0, down=0.5**4) solver = PCG(tol=1e-4, maxiter=250) # or CuDSS() -optimizer = LM(model, matrix_free_normal=True, strategy=strategy, solver=solver, reject=30) +optimizer = Schur(model, matrix_free_normal=True, strategy=strategy, solver=solver, reject=30) diff --git a/bae/optim/optimizer.py b/bae/optim/optimizer.py index 247b864..f6a338e 100644 --- a/bae/optim/optimizer.py +++ b/bae/optim/optimizer.py @@ -3,9 +3,12 @@ import torch from pypose.optim import LevenbergMarquardt as ppLM import pypose as pp + +from warp.optim import linear +from bae.sparse.warp_wrappers import format_vec_for_bsr, torchbsr2wp, wp2torchbsr from ..autograd.graph import jacobian from ..autograd.function import TrackingTensor -from ..sparse.py_ops import diagonal_op_ +from ..sparse.py_ops import diagonal_op_, inv_op from ..sparse.spgemm import CuSparse from ..utils.linear_operator import NormalMatVec @@ -84,3 +87,142 @@ def update_parameter(self, params, step): param[:, 7:] += d.view(param.shape[0], -1)[:, 6:] else: param.add_(d.view(param.shape)) + +import warp as wp +from warp import sparse +class Schur(LM): + @torch.no_grad() + def step(self, input, target=None, weight=None): + for pg in self.param_groups: + self.reject_count = 0 + weight = self.weight if weight is None else weight + R = self.model(input, target) + + R = R[0] + J = jacobian(R, pg['params']) + J[0] = J[0] + J[1] = J[1] + + self.last = self.loss = self.loss if hasattr(self, 'loss') \ + else self.model.loss(input, target) + # torch.cuda.nvtx.range_push("JTJc") + J0wp = torchbsr2wp(J[0]) + J0twp = sparse.bsr_transposed(J0wp) + U = sparse.bsr_mm(J0twp, J0wp) + # torch.cuda.nvtx.range_pop() + # J0D = J[0].to_dense() + # UD = U.to_dense() + # torch.testing.assert_close(UD, J0D.mT @ J0D) + # del J0D + # del UD + # torch.cuda.nvtx.range_push("JTJp") + J1wp = torchbsr2wp(J[1]) + J1twp = sparse.bsr_transposed(J1wp) + V = sparse.bsr_mm(J1twp, J1wp) + # torch.cuda.nvtx.range_pop() + # J1D = J[1].to_dense() + # VD = V.to_dense() + # torch.testing.assert_close(VD, J1D.mT @ J1D) + # del J1D + # del VD + + # torch.cuda.nvtx.range_push("Clamp") + Upt = wp2torchbsr(U) + Vpt = wp2torchbsr(V) + diagonal_op_(Upt, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) + diagonal_op_(Vpt, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) + # torch.cuda.nvtx.range_pop() + + while self.last <= self.loss: + damping = pg['damping'] + R = R.reshape(-1) + + # torch.cuda.nvtx.range_push("Damp") + # damp = lambda x: x.pow(2) * damping + x + damp = partial(torch.mul, other=1+damping) + diagonal_op_(Upt, op=damp) + diagonal_op_(Vpt, op=damp) + # sparse.bsr_set_diag(U, sparse.bsr_get_diag(U) * (1+pg['damping'])) + # sparse.bsr_set_diag(V, sparse.bsr_get_diag(V) * (1+pg['damping'])) + # torch.cuda.nvtx.range_pop() + + # torch.cuda.nvtx.range_push("W") + W = J0twp @ J1wp + # torch.cuda.nvtx.range_pop() + # torch.cuda.nvtx.range_push("Ic") + Rwp = format_vec_for_bsr(R, J0twp.block_shape) + Ic = sparse.bsr_mv(J0twp, Rwp, alpha=-1.0) + Ip = sparse.bsr_mv(J1twp, Rwp, alpha=-1.0) + # torch.cuda.nvtx.range_pop() + # torch.cuda.nvtx.range_push("Inv") + V_i = torchbsr2wp(inv_op(Vpt)) + # torch.cuda.nvtx.range_pop() + # torch.cuda.nvtx.range_push("WVi") + WV_i = W @ V_i + # torch.cuda.nvtx.range_pop() + # torch.cuda.nvtx.range_push("rhs1") + rhs = sparse.bsr_mv(WV_i, Ip, y=Ic, alpha=-1.0, beta=1.0) + # torch.cuda.nvtx.range_pop() + # torch.cuda.nvtx.range_push("lhs1") + Wt = W.transpose() + lhs = sparse.bsr_axpy(U, WV_i @ Wt, alpha=1.0, beta=-1.0) # this matrix is NOT symetric + # torch.cuda.nvtx.range_pop() + D_c = wp.zeros_like(rhs) + # torch.cuda.nvtx.range_push("Solve C") + solver_tol = getattr(self.solver, "tol", None) + solver_maxiter = getattr(self.solver, "maxiter", 0) or 0 + results = linear.cg( + A=lhs, + b=rhs, + x=D_c, + tol=solver_tol, + maxiter=solver_maxiter, + M=linear.preconditioner(lhs), + ) + + # torch.cuda.nvtx.range_pop() + + # torch.cuda.nvtx.range_push("rhs2") + + rhs = sparse.bsr_mv(Wt, D_c, alpha=-1.0, beta=1.0, y=Ip) # rhs = Ip - Wt @ D_c + # torch.cuda.nvtx.range_pop() + # torch.cuda.nvtx.range_push("solve2") + lhs = V + D_p = wp.zeros_like(rhs) + results = linear.cg( + A=lhs, + b=rhs, + x=D_p, + tol=solver_tol, + maxiter=solver_maxiter, + M=linear.preconditioner(lhs), + ) + # torch.cuda.nvtx.range_pop() + # torch.cuda.nvtx.range_push("Update") + D_c = wp.to_torch(D_c).flatten() + D_p = wp.to_torch(D_p).flatten() + D = torch.cat([D_c, D_p]) + self.update_parameter(pg['params'], D) + # torch.cuda.nvtx.range_pop() + self.loss = self.model.loss(input, target) + print("Loss:", self.loss, "Last Loss:", self.last, "Reject Count:", self.reject_count, "Damping:", pg['damping']) + # torch.cuda.nvtx.range_push("Strategy") + # self.strategy.update(pg, last=self.last, loss=self.loss, J=J, D=D, R=R.view(-1, 1)) + # Pass Warp-format Jacobians as well so strategies can do bsrmv without + # hitting PyTorch's CUDA BSR matvec limitation for rectangular blocks. + self.strategy.update( + pg, + last=self.last, + loss=self.loss, + J=J, + Jwp=[J0wp, J1wp], + D=[D_c, D_p], + R=R.view(-1, 1), + ) + # torch.cuda.nvtx.range_pop() + if self.last < self.loss and self.reject_count < self.reject: # reject step + self.update_parameter(params = pg['params'], step = -D) + self.loss, self.reject_count = self.last, self.reject_count + 1 + else: + break + return self.loss diff --git a/bae/sparse/py_ops.py b/bae/sparse/py_ops.py index 464241e..6f7ba0b 100644 --- a/bae/sparse/py_ops.py +++ b/bae/sparse/py_ops.py @@ -172,7 +172,12 @@ def inv_op(input): bsr_values = input.values() # 1 + 2 dimensional inv_values = torch.linalg.inv(bsr_values) - return torch.sparse_bsc_tensor(crow_indices, col_indices, inv_values) + return torch.sparse_bsr_tensor( + crow_indices=crow_indices, + col_indices=col_indices, + values=inv_values, + size=input.shape, + dtype=input.dtype, ) def to_cooh(input): crow_indices = input.crow_indices() # b + 1 dimensional From 3e4761d82f1fee819a7fcfc45c34cf0c22f53f03 Mon Sep 17 00:00:00 2001 From: Seokwoo Park Date: Mon, 20 Apr 2026 00:28:19 -0400 Subject: [PATCH 13/28] Preventing TrustRegion from accepting diverging steps --- ba_example.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/ba_example.py b/ba_example.py index 87bda06..7b4f71a 100644 --- a/ba_example.py +++ b/ba_example.py @@ -35,16 +35,19 @@ def _format_bytes(num_bytes: int) -> str: + sign = "-" if num_bytes < 0 else "" + size = float(abs(num_bytes)) units = ["B", "KiB", "MiB", "GiB", "TiB"] - size = float(num_bytes) - unit = units[0] + for unit in units: if size < 1024.0 or unit == units[-1]: break size /= 1024.0 + if unit == "B": - return f"{int(size)} {unit}" - return f"{size:.2f} {unit}" + return f"{sign}{int(size)} {unit}" + + return f"{sign}{size:.2f} {unit}" @map_transform @@ -93,7 +96,13 @@ def update(self, pg, last, loss, J, D, R, *args, **kwargs): else: JD += J[i] @ D[i].flatten() JD = JD[..., None] - quality = (last - loss) / -((JD).mT @ (2 * R.view_as(JD) + JD)).squeeze() + denom = -((JD).mT @ (2 * R.view_as(JD) + JD)).squeeze() + + if loss >= last or denom <= 0: + quality = -1.0 + else: + quality = (last - loss) / denom + pg['radius'] = 1. / pg['damping'] if quality > pg['high']: pg['radius'] = pg['up'] * pg['radius'] @@ -144,11 +153,6 @@ def main(): if isinstance(value, torch.Tensor) } - # input = { - # "observes": dataset["points_2d"], - # "cidx": dataset["camera_index_of_observations"], - # "pidx": dataset["point_index_of_observations"], - # } input = { "points_2d": dataset["points_2d"], "camera_indices": dataset["camera_index_of_observations"], @@ -205,16 +209,17 @@ def main(): loss = optimizer.step(input) print('Iteration', idx, 'loss', loss.item(), 'time', perf_counter() - start) - if memory_snapshot_path: - torch.cuda.synchronize(cuda_device) - torch.cuda.memory._dump_snapshot(str(memory_snapshot_path)) - print(f"CUDA memory snapshot saved to {memory_snapshot_path}") - if cuda_device is not None and torch.cuda.is_available(): torch.cuda.synchronize(cuda_device) end = perf_counter() + print('Time', end - start) + if memory_snapshot_path: + torch.cuda.synchronize(cuda_device) + torch.cuda.memory._dump_snapshot(str(memory_snapshot_path)) + print(f"CUDA memory snapshot saved to {memory_snapshot_path}") + if cuda_device is not None and torch.cuda.is_available(): peak_allocated = torch.cuda.max_memory_allocated(cuda_device) try: From 5d9e2b2945133dd6be429996ba3ace4b2d1ea5ea Mon Sep 17 00:00:00 2001 From: Seokwoo Park Date: Tue, 28 Apr 2026 21:54:37 -0400 Subject: [PATCH 14/28] fix(optimizer/LM): Remove redundant solver calls so matrix_free_normal path runs --- bae/optim/optimizer.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/bae/optim/optimizer.py b/bae/optim/optimizer.py index 5e3ed0e..af5c3ff 100644 --- a/bae/optim/optimizer.py +++ b/bae/optim/optimizer.py @@ -1,8 +1,9 @@ -from functools import partial import torch -from pypose.optim import LevenbergMarquardt as ppLM import pypose as pp - +import warp as wp +from functools import partial +from pypose.optim import LevenbergMarquardt as ppLM +from warp import sparse from warp.optim import linear from bae.sparse.warp_wrappers import format_vec_for_bsr, torchbsr2wp, wp2torchbsr from ..autograd.graph import jacobian @@ -13,8 +14,6 @@ from ..utils.parameter import parameter_update_shape - - class LM(ppLM): def __init__(self, *args, matrix_free_normal: bool = False, **kwargs): self.matrix_free_normal = matrix_free_normal @@ -25,16 +24,17 @@ def __init__(self, *args, matrix_free_normal: bool = False, **kwargs): def step(self, input, target=None, weight=None): for pg in self.param_groups: weight = self.weight if weight is None else weight - R = list(self.model(input)) - R = R[0] + R = self.model(input)[0] J_list = jacobian(R, pg['params']) + if isinstance(R, TrackingTensor): R = R.tensor() - J = torch.cat([j.to_sparse_coo() for j in J_list], dim=-1) + + J = torch.cat([j.to_sparse_coo() for j in J_list], dim=-1).to_sparse_csr() + del J_list self.last = self.loss = self.loss if hasattr(self, 'loss') else self.model.loss(input, target) self.reject_count = 0 - J = J.to_sparse_csr() if self.matrix_free_normal: diag = NormalMatVec._compute_diag(J).clamp(min=pg['min'], max=pg['max']) @@ -45,6 +45,7 @@ def step(self, input, target=None, weight=None): J_T = J.mT.to_sparse_csr() rhs = -J_T @ R.view(-1, 1) A = self.mm(J_T, J) + del J_T diagonal_op_(A, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) while self.last <= self.loss: @@ -55,8 +56,6 @@ def step(self, input, target=None, weight=None): diagonal_op_(A, op=partial(torch.mul, other=1+pg['damping'])) try: D = self.solver(A, rhs) - D = D[:, None] - D = self.solver(A, -J_T @ R.view(-1, 1)) except Exception as e: print(e, "\nLinear solver failed. Breaking optimization step...") break @@ -87,8 +86,6 @@ def update_parameter(self, params, step): else: param.add_(d.view(param.shape)) -import warp as wp -from warp import sparse class Schur(LM): @torch.no_grad() @@ -226,4 +223,3 @@ def step(self, input, target=None, weight=None): else: break return self.loss - # param.add_(step_view) From e34bea206146214144d164f4e5444e83e9f3c652 Mon Sep 17 00:00:00 2001 From: Seokwoo Park Date: Tue, 28 Apr 2026 22:07:10 -0400 Subject: [PATCH 15/28] feat(optim/Schur): Add Matrix-Free path and matrix_free_normal branch --- bae/optim/optimizer.py | 186 +++++++++++++++++++++-------------------- 1 file changed, 96 insertions(+), 90 deletions(-) diff --git a/bae/optim/optimizer.py b/bae/optim/optimizer.py index af5c3ff..44c6d82 100644 --- a/bae/optim/optimizer.py +++ b/bae/optim/optimizer.py @@ -93,133 +93,139 @@ def step(self, input, target=None, weight=None): for pg in self.param_groups: self.reject_count = 0 weight = self.weight if weight is None else weight - R = self.model(input, target) - - R = R[0] + R = self.model(input, target)[0] J = jacobian(R, pg['params']) - J[0] = J[0] - J[1] = J[1] - self.last = self.loss = self.loss if hasattr(self, 'loss') \ - else self.model.loss(input, target) - # torch.cuda.nvtx.range_push("JTJc") + self.last = self.loss = self.loss if hasattr(self, 'loss') else self.model.loss(input, target) + J0wp = torchbsr2wp(J[0]) - J0twp = sparse.bsr_transposed(J0wp) - U = sparse.bsr_mm(J0twp, J0wp) - # torch.cuda.nvtx.range_pop() - # J0D = J[0].to_dense() - # UD = U.to_dense() - # torch.testing.assert_close(UD, J0D.mT @ J0D) - # del J0D - # del UD - # torch.cuda.nvtx.range_push("JTJp") J1wp = torchbsr2wp(J[1]) + J0twp = sparse.bsr_transposed(J0wp) J1twp = sparse.bsr_transposed(J1wp) + U = sparse.bsr_mm(J0twp, J0wp) V = sparse.bsr_mm(J1twp, J1wp) - # torch.cuda.nvtx.range_pop() - # J1D = J[1].to_dense() - # VD = V.to_dense() - # torch.testing.assert_close(VD, J1D.mT @ J1D) - # del J1D - # del VD - - # torch.cuda.nvtx.range_push("Clamp") + + if self.matrix_free_normal: + del J0twp, J1twp + else: + W = sparse.bsr_mm(J0twp, J1wp) + Wt = sparse.bsr_transposed(W) + del J0twp, J1twp + Upt = wp2torchbsr(U) Vpt = wp2torchbsr(V) diagonal_op_(Upt, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) diagonal_op_(Vpt, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) - # torch.cuda.nvtx.range_pop() + R_flat = R.reshape(-1).contiguous() + Rwp = format_vec_for_bsr(R_flat, (J0wp.block_shape[1], J0wp.block_shape[0])) + Ic = sparse.bsr_mv(J0wp, Rwp, alpha=-1.0, transpose=True) + Ip = sparse.bsr_mv(J1wp, Rwp, alpha=-1.0, transpose=True) + rhs_c = wp.empty_like(Ic) + rhs_p = wp.empty_like(Ip) + scratch_pts2 = wp.empty_like(Ip) + + if self.matrix_free_normal: + scratch_obs = wp.empty_like(Rwp) + scratch_pts = wp.empty_like(Ip) + + solver_tol = getattr(self.solver, "tol", None) + solver_maxiter = getattr(self.solver, "maxiter", 0) or 0 while self.last <= self.loss: - damping = pg['damping'] - R = R.reshape(-1) - - # torch.cuda.nvtx.range_push("Damp") - # damp = lambda x: x.pow(2) * damping + x - damp = partial(torch.mul, other=1+damping) + damp = partial(torch.mul, other=1+pg['damping']) diagonal_op_(Upt, op=damp) diagonal_op_(Vpt, op=damp) - # sparse.bsr_set_diag(U, sparse.bsr_get_diag(U) * (1+pg['damping'])) - # sparse.bsr_set_diag(V, sparse.bsr_get_diag(V) * (1+pg['damping'])) - # torch.cuda.nvtx.range_pop() - - # torch.cuda.nvtx.range_push("W") - W = J0twp @ J1wp - # torch.cuda.nvtx.range_pop() - # torch.cuda.nvtx.range_push("Ic") - Rwp = format_vec_for_bsr(R, J0twp.block_shape) - Ic = sparse.bsr_mv(J0twp, Rwp, alpha=-1.0) - Ip = sparse.bsr_mv(J1twp, Rwp, alpha=-1.0) - # torch.cuda.nvtx.range_pop() - # torch.cuda.nvtx.range_push("Inv") + V_i = torchbsr2wp(inv_op(Vpt)) - # torch.cuda.nvtx.range_pop() - # torch.cuda.nvtx.range_push("WVi") - WV_i = W @ V_i - # torch.cuda.nvtx.range_pop() - # torch.cuda.nvtx.range_push("rhs1") - rhs = sparse.bsr_mv(WV_i, Ip, y=Ic, alpha=-1.0, beta=1.0) - # torch.cuda.nvtx.range_pop() - # torch.cuda.nvtx.range_push("lhs1") - Wt = W.transpose() - lhs = sparse.bsr_axpy(U, WV_i @ Wt, alpha=1.0, beta=-1.0) # this matrix is NOT symetric - # torch.cuda.nvtx.range_pop() - D_c = wp.zeros_like(rhs) - # torch.cuda.nvtx.range_push("Solve C") - solver_tol = getattr(self.solver, "tol", None) - solver_maxiter = getattr(self.solver, "maxiter", 0) or 0 - results = linear.cg( - A=lhs, - b=rhs, + + if self.matrix_free_normal: + def schur_matvec(x, y, z, alpha, beta, _V_i=V_i): + sparse.bsr_mv(J0wp, x, y=scratch_obs, beta=0.0) + sparse.bsr_mv(J1wp, scratch_obs, y=scratch_pts, beta=0.0, transpose=True) + sparse.bsr_mv(_V_i, scratch_pts, y=scratch_pts2, beta=0.0) + sparse.bsr_mv(J1wp, scratch_pts2, y=scratch_obs, beta=0.0) + if z.ptr != y.ptr and beta != 0.0: + wp.copy(src=y, dest=z) + sparse.bsr_mv(J0wp, scratch_obs, y=z, alpha=-alpha, beta=beta, transpose=True) + sparse.bsr_mv(U, x, y=z, alpha=alpha, beta=1.0) + + schur_op = linear.LinearOperator( + shape=U.shape, dtype=U.values.dtype, device=U.device, + matvec=schur_matvec, + ) + schur_M = linear.preconditioner(U) + + wp.copy(src=Ic, dest=rhs_c) + sparse.bsr_mv(V_i, Ip, y=scratch_pts2, beta=0.0) + sparse.bsr_mv(J1wp, scratch_pts2, y=scratch_obs, beta=0.0) + sparse.bsr_mv(J0wp, scratch_obs, y=rhs_c, alpha=-1.0, beta=1.0, transpose=True) + else: + WV_i = sparse.bsr_mm(W, V_i) + WVi_Wt = sparse.bsr_mm(WV_i, Wt) + U_clone_torch = torch.sparse_bsr_tensor( + crow_indices=Upt.crow_indices().clone(), + col_indices=Upt.col_indices().clone(), + values=Upt.values().clone(), + size=Upt.shape, device=Upt.device, dtype=Upt.dtype, + ) + schur_op = sparse.bsr_axpy(WVi_Wt, torchbsr2wp(U_clone_torch), alpha=-1.0) + schur_M = linear.preconditioner(schur_op) + wp.copy(src=Ic, dest=rhs_c) + sparse.bsr_mv(V_i, Ip, y=scratch_pts2, beta=0.0) + sparse.bsr_mv(W, scratch_pts2, y=rhs_c, alpha=-1.0, beta=1.0) + + D_c = wp.zeros_like(rhs_c) + linear.cg( + A=schur_op, + b=rhs_c, x=D_c, tol=solver_tol, maxiter=solver_maxiter, - M=linear.preconditioner(lhs), + M=schur_M, ) - # torch.cuda.nvtx.range_pop() - # torch.cuda.nvtx.range_push("rhs2") + wp.copy(src=Ip, dest=rhs_p) - rhs = sparse.bsr_mv(Wt, D_c, alpha=-1.0, beta=1.0, y=Ip) # rhs = Ip - Wt @ D_c - # torch.cuda.nvtx.range_pop() - # torch.cuda.nvtx.range_push("solve2") - lhs = V - D_p = wp.zeros_like(rhs) - results = linear.cg( - A=lhs, - b=rhs, + if self.matrix_free_normal: + sparse.bsr_mv(J0wp, D_c, y=scratch_obs, beta=0.0) + sparse.bsr_mv(J1wp, scratch_obs, y=rhs_p, + alpha=-1.0, beta=1.0, transpose=True) + else: + sparse.bsr_mv(Wt, D_c, y=rhs_p, alpha=-1.0, beta=1.0) + + D_p = wp.zeros_like(rhs_p) + linear.cg( + A=V, + b=rhs_p, x=D_p, tol=solver_tol, maxiter=solver_maxiter, - M=linear.preconditioner(lhs), + M=linear.preconditioner(V), ) - # torch.cuda.nvtx.range_pop() - # torch.cuda.nvtx.range_push("Update") - D_c = wp.to_torch(D_c).flatten() - D_p = wp.to_torch(D_p).flatten() - D = torch.cat([D_c, D_p]) + + D_c_t = wp.to_torch(D_c).flatten() + D_p_t = wp.to_torch(D_p).flatten() + D = torch.cat([D_c_t, D_p_t]) self.update_parameter(pg['params'], D) - # torch.cuda.nvtx.range_pop() self.loss = self.model.loss(input, target) print("Loss:", self.loss, "Last Loss:", self.last, "Reject Count:", self.reject_count, "Damping:", pg['damping']) - # torch.cuda.nvtx.range_push("Strategy") - # self.strategy.update(pg, last=self.last, loss=self.loss, J=J, D=D, R=R.view(-1, 1)) - # Pass Warp-format Jacobians as well so strategies can do bsrmv without - # hitting PyTorch's CUDA BSR matvec limitation for rectangular blocks. + self.strategy.update( pg, last=self.last, loss=self.loss, J=J, Jwp=[J0wp, J1wp], - D=[D_c, D_p], - R=R.view(-1, 1), + D=[D_c_t, D_p_t], + R=R_flat.view(-1, 1), ) - # torch.cuda.nvtx.range_pop() - if self.last < self.loss and self.reject_count < self.reject: # reject step - self.update_parameter(params = pg['params'], step = -D) + + if self.last < self.loss and self.reject_count < self.reject: # reject step + self.update_parameter(params=pg['params'], step=-D) self.loss, self.reject_count = self.last, self.reject_count + 1 else: break + return self.loss + From f64d00b83fa0c8db3d6883fbef8aa7a3887d0ec2 Mon Sep 17 00:00:00 2001 From: Seokwoo Park Date: Tue, 28 Apr 2026 22:48:35 -0400 Subject: [PATCH 16/28] Version up to 0.2.1 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c2bd099..8693df8 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension -VERSION = "0.2" +VERSION = "0.2.1" def readme(): """Read the README.md file for long description""" From 40798f1963d7bd68d0f15365615483a1a2996925 Mon Sep 17 00:00:00 2001 From: SEOKWOOPARK Date: Sat, 23 May 2026 19:36:32 +0000 Subject: [PATCH 17/28] Fix deprecated function in Warp --- bae/sparse/warp_wrappers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bae/sparse/warp_wrappers.py b/bae/sparse/warp_wrappers.py index d5aa5fa..49598b0 100644 --- a/bae/sparse/warp_wrappers.py +++ b/bae/sparse/warp_wrappers.py @@ -9,7 +9,7 @@ def torchbsr2wp(tbsr): assert tbsr.layout == torch.sparse_bsr - block_type = wp.mat(shape=tbsr.values().shape[-2:], dtype=wp.dtype_from_torch(tbsr.dtype)) + block_type = wp.types.matrix(shape=tbsr.values().shape[-2:], dtype=wp.dtype_from_torch(tbsr.dtype)) bsr = wps.bsr_matrix_t(block_type)() bsr.nrow = int(tbsr.shape[0] // block_type._shape_[0]) bsr.ncol = int(tbsr.shape[1] // block_type._shape_[1]) @@ -57,7 +57,7 @@ def wp2torchbsr(bsr): def format_vec_for_bsr(tvec, block_shape): y_vec_len = block_shape[1] - y_dtype = wp.vec(length=y_vec_len, dtype=wp.dtype_from_torch(tvec.dtype)) + y_dtype = wp.types.vector(length=y_vec_len, dtype=wp.dtype_from_torch(tvec.dtype)) if tvec.ndim == 1 and tvec.shape[-1] != y_vec_len: tvec = tvec.reshape(-1, y_vec_len) vwp = wp.from_torch(tvec, dtype=y_dtype) From 165104d8bbf5d6125c2a2dfebdb7f9cd0a537cd0 Mon Sep 17 00:00:00 2001 From: SEOKWOOPARK Date: Sat, 23 May 2026 22:17:13 +0000 Subject: [PATCH 18/28] Replace Warp with Triton kernels and adjust corresponding codes --- ba_example.py | 48 ++- bae/autograd/graph.py | 31 +- bae/optim/optimizer.py | 236 ++++++------ bae/optim/triton_kernel.py | 718 +++++++++++++++++++++++++++++++++++++ 4 files changed, 913 insertions(+), 120 deletions(-) create mode 100644 bae/optim/triton_kernel.py diff --git a/ba_example.py b/ba_example.py index 7b4f71a..3445fcc 100644 --- a/ba_example.py +++ b/ba_example.py @@ -4,15 +4,13 @@ import torch import pypose as pp import warp as wp -from warp import sparse as wpsparse - from ba_helpers import Reproj, least_square_error from bae.optim.optimizer import Schur +from bae.optim.triton_kernel import sparse_bsr_mv from datapipes.bal_loader import get_problem, read_bal_data from bae.sparse.py_ops import * from bae.optim import LM from bae.utils.pysolvers import PCG, CuDSS -from bae.sparse.warp_wrappers import format_vec_for_bsr import torch.nn as nn from bae.autograd.function import TrackingTensor, map_transform @@ -20,11 +18,14 @@ TARGET_DATASET = "trafalgar" TARGET_PROBLEM = "problem-257-65132-pre" -# other options: # TARGET_DATASET = "ladybug" # TARGET_PROBLEM = "problem-1723-156502-pre" # TARGET_DATASET = "dubrovnik" # TARGET_PROBLEM = "problem-356-226730-pre" +# TARGET_DATASET = "final" +# TARGET_PROBLEM = "problem-13682-4456117-pre" +# TARGET_DATASET = "venice" +# TARGET_PROBLEM = "problem-1778-993923-pre" DEVICE = "cuda" OPTIMIZE_INTRINSICS = True @@ -81,20 +82,16 @@ def update(self, pg, last, loss, J, D, R, *args, **kwargs): Jwp = kwargs.get("Jwp") if Jwp is not None: J = Jwp + JD = None + for i in range(len(D)): - if JD is None: - if Jwp is not None: - Dwp = format_vec_for_bsr(D[i].flatten().contiguous(), J[i].block_shape) - JD = wp.to_torch(wpsparse.bsr_mv(J[i], Dwp)).flatten() - else: - JD = J[i] @ D[i].flatten() + if Jwp is not None: + JD_i = sparse_bsr_mv(J[i], D[i].flatten().contiguous()).flatten() else: - if Jwp is not None: - Dwp = format_vec_for_bsr(D[i].flatten().contiguous(), J[i].block_shape) - JD += wp.to_torch(wpsparse.bsr_mv(J[i], Dwp)).flatten() - else: - JD += J[i] @ D[i].flatten() + JD_i = J[i] @ D[i].flatten() + JD = JD_i if JD is None else JD + JD_i + JD = JD[..., None] denom = -((JD).mT @ (2 * R.view_as(JD) + JD)).squeeze() @@ -145,6 +142,8 @@ def main(): warp_device = None warp_mempool_start_current = None warp_mempool_start_high = None + total_memory = None + nontorch_baseline = None dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET, use_quat=USE_QUATERNIONS) dataset = { @@ -171,6 +170,8 @@ def main(): device=cuda_device, clear_history=True, ) + else: + print("CUDA is not available; skipping CUDA memory tracking.") if REPORT_WARP_MEMPOOL and DEVICE.startswith("cuda"): try: @@ -190,7 +191,7 @@ def main(): strategy = TrustRegion(up=2.0, down=0.5**4) solver = PCG(tol=1e-4, maxiter=250) - optimizer = Schur(model, strategy=strategy, solver=solver, reject=30) + optimizer = Schur(model, strategy=strategy, solver=solver, reject=30, matrix_free_normal=True) print('Initial loss:', least_square_error( model.pose, @@ -203,6 +204,10 @@ def main(): if cuda_device is not None and torch.cuda.is_available(): torch.cuda.synchronize(cuda_device) torch.cuda.reset_peak_memory_stats(cuda_device) + free_baseline, total_memory = torch.cuda.mem_get_info(cuda_device) + torch_reserved_baseline = torch.cuda.memory_reserved(cuda_device) + warp_current_baseline = (wp.get_mempool_used_mem_current(warp_device) if warp_device is not None else 0) + nontorch_baseline = ((total_memory - free_baseline) - torch_reserved_baseline - warp_current_baseline) start = perf_counter() for idx in range(20): @@ -222,12 +227,23 @@ def main(): if cuda_device is not None and torch.cuda.is_available(): peak_allocated = torch.cuda.max_memory_allocated(cuda_device) + try: peak_reserved = torch.cuda.max_memory_reserved(cuda_device) except AttributeError: peak_reserved = torch.cuda.max_memory_cached(cuda_device) + + free_end, _ = torch.cuda.mem_get_info(cuda_device) + torch_reserved_end = torch.cuda.memory_reserved(cuda_device) + warp_current_end = (wp.get_mempool_used_mem_current(warp_device) if warp_device is not None else 0) + nontorch_end = ((total_memory - free_end) - torch_reserved_end - warp_current_end) + module_growth = nontorch_end - nontorch_baseline if nontorch_baseline is not None else None + print(f"Peak CUDA memory allocated: {_format_bytes(peak_allocated)}") print(f"Peak CUDA memory reserved: {_format_bytes(peak_reserved)}") + print(f"Non-allocator CUDA memory (context + kernel modules): {_format_bytes(nontorch_end)}") + if module_growth is not None: + print(f"Kernel-module growth during run (Triton JIT binaries): {_format_bytes(module_growth)}") if warp_device is not None and warp_mempool_start_current is not None: try: diff --git a/bae/autograd/graph.py b/bae/autograd/graph.py index e041ee5..9f7599f 100644 --- a/bae/autograd/graph.py +++ b/bae/autograd/graph.py @@ -1,4 +1,3 @@ - from typing import Optional import pypose as pp @@ -180,8 +179,34 @@ def backward(output_): if len(argnums) == 0: warning("No upstream parameters to compute jacobian") return - with pp.retain_ltype(): - jac_blocks = torch.vmap(jacrev(func, argnums=argnums))(*args) + + total_obs = args[0].shape[0] if len(args) > 0 else 0 + max_obs_per_chunk = 50_000 + if total_obs <= max_obs_per_chunk: + with pp.retain_ltype(): + jac_blocks = torch.vmap(jacrev(func, argnums=argnums))(*args) + else: + chunks = [] + with pp.retain_ltype(): + for start in range(0, total_obs, max_obs_per_chunk): + end = min(start + max_obs_per_chunk, total_obs) + chunk_args = [a[start:end] for a in args] + chunks.append(list( + torch.vmap(jacrev(func, argnums=argnums))(*chunk_args) + )) + n_args = len(argnums) + jac_blocks = [] + + for i in range(n_args): + parts = [c[i] for c in chunks] + jac_blocks.append(torch.cat(parts, dim=0)) + del parts + + for c in chunks: + c[i] = None + + jac_blocks = tuple(jac_blocks) + del chunks for jacidx, argidx in enumerate(argnums): jac_block = jac_blocks[jacidx] arg = args[argidx] diff --git a/bae/optim/optimizer.py b/bae/optim/optimizer.py index 44c6d82..f4a5f4d 100644 --- a/bae/optim/optimizer.py +++ b/bae/optim/optimizer.py @@ -1,11 +1,12 @@ import torch import pypose as pp -import warp as wp from functools import partial from pypose.optim import LevenbergMarquardt as ppLM -from warp import sparse -from warp.optim import linear -from bae.sparse.warp_wrappers import format_vec_for_bsr, torchbsr2wp, wp2torchbsr +from .triton_kernel import ( + sparse_bsr_mm, sparse_bsr_mv, + sparse_bsr_transposed, sparse_bsr_axpy, + BlockJacobi, cg, +) from ..autograd.graph import jacobian from ..autograd.function import TrackingTensor from ..sparse.py_ops import diagonal_op_, inv_op @@ -15,11 +16,38 @@ class LM(ppLM): - def __init__(self, *args, matrix_free_normal: bool = False, **kwargs): + def __init__(self, *args, matrix_free_normal: bool = False, loss_chunk_size: int = 100_000, **kwargs): self.matrix_free_normal = matrix_free_normal + self.loss_chunk_size = loss_chunk_size super(LM, self).__init__(*args, **kwargs) self.mm = CuSparse() + def _chunked_model_loss(self, input, target=None): + m = self.model + if not isinstance(input, dict): + return m.loss(input, target) + obs_axis_keys = ("points_2d", "camera_indices", "point_indices") + if not all(k in input for k in obs_axis_keys): + return m.loss(input, target) + n = input["points_2d"].shape[0] + chunk = self.loss_chunk_size + if chunk <= 0 or n <= chunk: + return m.loss(input, target) + + total = None + for start in range(0, n, chunk): + end = min(start + chunk, n) + chunk_input = {k: input[k][start:end] for k in obs_axis_keys} + output = m.model_forward(chunk_input) + chunk_residuals = m.residuals(output, target) + if len(m.kernel) > 1: + parts = [k(r.square().sum(-1)).sum() for k, r in zip(m.kernel, chunk_residuals)] + else: + parts = [m.kernel[0](r.square().sum(-1)).sum() for r in chunk_residuals] + chunk_loss = sum(parts) + total = chunk_loss if total is None else total + chunk_loss + return total + @torch.no_grad() def step(self, input, target=None, weight=None): for pg in self.param_groups: @@ -33,7 +61,7 @@ def step(self, input, target=None, weight=None): J = torch.cat([j.to_sparse_coo() for j in J_list], dim=-1).to_sparse_csr() del J_list - self.last = self.loss = self.loss if hasattr(self, 'loss') else self.model.loss(input, target) + self.last = self.loss = self.loss if hasattr(self, 'loss') else self._chunked_model_loss(input, target) self.reject_count = 0 if self.matrix_free_normal: @@ -60,7 +88,7 @@ def step(self, input, target=None, weight=None): print(e, "\nLinear solver failed. Breaking optimization step...") break self.update_parameter(pg['params'], D) - self.loss = self.model.loss(input, target) + self.loss = self._chunked_model_loss(input, target) print("Loss:", self.loss, "Last Loss:", self.last, "Reject Count:", self.reject_count, "Damping:", pg['damping']) self.strategy.update(pg, last=self.last, loss=self.loss, J=J, D=D, R=R.view(-1, 1)) if self.last < self.loss and self.reject_count < self.reject: # reject step @@ -95,120 +123,126 @@ def step(self, input, target=None, weight=None): weight = self.weight if weight is None else weight R = self.model(input, target)[0] J = jacobian(R, pg['params']) + + if isinstance(R, TrackingTensor): + R = R.tensor() + else: + R = R.detach() + torch.cuda.empty_cache() - self.last = self.loss = self.loss if hasattr(self, 'loss') else self.model.loss(input, target) - - J0wp = torchbsr2wp(J[0]) - J1wp = torchbsr2wp(J[1]) - J0twp = sparse.bsr_transposed(J0wp) - J1twp = sparse.bsr_transposed(J1wp) - U = sparse.bsr_mm(J0twp, J0wp) - V = sparse.bsr_mm(J1twp, J1wp) + self.last = self.loss = self.loss if hasattr(self, 'loss') else self._chunked_model_loss(input, target) + J0 = J[0] + J1 = J[1] if self.matrix_free_normal: - del J0twp, J1twp + J0t = sparse_bsr_transposed(J0) + U = sparse_bsr_mm(J0t, J0) + del J0t + J1t = sparse_bsr_transposed(J1) + V = sparse_bsr_mm(J1t, J1) + del J1t else: - W = sparse.bsr_mm(J0twp, J1wp) - Wt = sparse.bsr_transposed(W) - del J0twp, J1twp - - Upt = wp2torchbsr(U) - Vpt = wp2torchbsr(V) - diagonal_op_(Upt, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) - diagonal_op_(Vpt, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) + J0t = sparse_bsr_transposed(J0) + J1t = sparse_bsr_transposed(J1) + U = sparse_bsr_mm(J0t, J0) + V = sparse_bsr_mm(J1t, J1) + W = sparse_bsr_mm(J0t, J1) + Wt = sparse_bsr_transposed(W) + del J0t, J1t + + diagonal_op_(U, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) + diagonal_op_(V, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) R_flat = R.reshape(-1).contiguous() - Rwp = format_vec_for_bsr(R_flat, (J0wp.block_shape[1], J0wp.block_shape[0])) - Ic = sparse.bsr_mv(J0wp, Rwp, alpha=-1.0, transpose=True) - Ip = sparse.bsr_mv(J1wp, Rwp, alpha=-1.0, transpose=True) - rhs_c = wp.empty_like(Ic) - rhs_p = wp.empty_like(Ip) - scratch_pts2 = wp.empty_like(Ip) + Ic = sparse_bsr_mv(J0, R_flat, alpha=-1.0, transpose=True) + Ip = sparse_bsr_mv(J1, R_flat, alpha=-1.0, transpose=True) + rhs_c = torch.empty_like(Ic) + rhs_p = torch.empty_like(Ip) + scratch_pts2 = torch.empty_like(Ip) + schur_Ap_buf = torch.empty_like(Ic) + v_Ap_buf = torch.empty_like(Ip) + D_c = torch.empty_like(Ic) + D_p = torch.empty_like(Ip) + cg_r_buf_c = torch.empty_like(Ic) + cg_p_buf_c = torch.empty_like(Ic) + cg_r_buf_p = torch.empty_like(Ip) + cg_p_buf_p = torch.empty_like(Ip) if self.matrix_free_normal: - scratch_obs = wp.empty_like(Rwp) - scratch_pts = wp.empty_like(Ip) + scratch_obs = torch.empty_like(R_flat) + scratch_pts = torch.empty_like(Ip) + z_buf = torch.empty_like(Ic) - solver_tol = getattr(self.solver, "tol", None) + solver_tol = getattr(self.solver, "tol", None) or 1e-5 solver_maxiter = getattr(self.solver, "maxiter", 0) or 0 + mm_cache_WV_i = {} if not self.matrix_free_normal else None + mm_cache_WVi_Wt = {} if not self.matrix_free_normal else None + axpy_cache_schur = {} if not self.matrix_free_normal else None + v_M = BlockJacobi(V) + schur_M = BlockJacobi(U) if self.matrix_free_normal else None while self.last <= self.loss: damp = partial(torch.mul, other=1+pg['damping']) - diagonal_op_(Upt, op=damp) - diagonal_op_(Vpt, op=damp) - - V_i = torchbsr2wp(inv_op(Vpt)) + diagonal_op_(U, op=damp) + diagonal_op_(V, op=damp) + V_i = inv_op(V) if self.matrix_free_normal: - def schur_matvec(x, y, z, alpha, beta, _V_i=V_i): - sparse.bsr_mv(J0wp, x, y=scratch_obs, beta=0.0) - sparse.bsr_mv(J1wp, scratch_obs, y=scratch_pts, beta=0.0, transpose=True) - sparse.bsr_mv(_V_i, scratch_pts, y=scratch_pts2, beta=0.0) - sparse.bsr_mv(J1wp, scratch_pts2, y=scratch_obs, beta=0.0) - if z.ptr != y.ptr and beta != 0.0: - wp.copy(src=y, dest=z) - sparse.bsr_mv(J0wp, scratch_obs, y=z, alpha=-alpha, beta=beta, transpose=True) - sparse.bsr_mv(U, x, y=z, alpha=alpha, beta=1.0) - - schur_op = linear.LinearOperator( - shape=U.shape, dtype=U.values.dtype, device=U.device, - matvec=schur_matvec, - ) - schur_M = linear.preconditioner(U) - - wp.copy(src=Ic, dest=rhs_c) - sparse.bsr_mv(V_i, Ip, y=scratch_pts2, beta=0.0) - sparse.bsr_mv(J1wp, scratch_pts2, y=scratch_obs, beta=0.0) - sparse.bsr_mv(J0wp, scratch_obs, y=rhs_c, alpha=-1.0, beta=1.0, transpose=True) + def schur_matvec(p, _V_i=V_i, _z=schur_Ap_buf): + sparse_bsr_mv(J0, p, y=scratch_obs, beta=0.0) + sparse_bsr_mv(J1, scratch_obs, y=scratch_pts, beta=0.0, transpose=True) + sparse_bsr_mv(_V_i, scratch_pts, y=scratch_pts2, beta=0.0) + sparse_bsr_mv(J1, scratch_pts2, y=scratch_obs, beta=0.0) + sparse_bsr_mv(J0, scratch_obs, y=_z, alpha=-1.0, beta=0.0, transpose=True) + sparse_bsr_mv(U, p, y=_z, alpha=1.0, beta=1.0) + return _z + + matvec_fn = schur_matvec + schur_M.refresh(U) + rhs_c.copy_(Ic) + sparse_bsr_mv(V_i, Ip, y=scratch_pts2, beta=0.0) + sparse_bsr_mv(J1, scratch_pts2, y=scratch_obs, beta=0.0) + sparse_bsr_mv(J0, scratch_obs, y=rhs_c, alpha=-1.0, beta=1.0, transpose=True) else: - WV_i = sparse.bsr_mm(W, V_i) - WVi_Wt = sparse.bsr_mm(WV_i, Wt) - U_clone_torch = torch.sparse_bsr_tensor( - crow_indices=Upt.crow_indices().clone(), - col_indices=Upt.col_indices().clone(), - values=Upt.values().clone(), - size=Upt.shape, device=Upt.device, dtype=Upt.dtype, - ) - schur_op = sparse.bsr_axpy(WVi_Wt, torchbsr2wp(U_clone_torch), alpha=-1.0) - schur_M = linear.preconditioner(schur_op) - wp.copy(src=Ic, dest=rhs_c) - sparse.bsr_mv(V_i, Ip, y=scratch_pts2, beta=0.0) - sparse.bsr_mv(W, scratch_pts2, y=rhs_c, alpha=-1.0, beta=1.0) - - D_c = wp.zeros_like(rhs_c) - linear.cg( - A=schur_op, - b=rhs_c, - x=D_c, - tol=solver_tol, - maxiter=solver_maxiter, - M=schur_M, - ) + WV_i = sparse_bsr_mm(W, V_i, topology_cache=mm_cache_WV_i) + WVi_Wt = sparse_bsr_mm(WV_i, Wt, topology_cache=mm_cache_WVi_Wt) + del WV_i + schur_op = sparse_bsr_axpy(WVi_Wt, U, alpha=-1.0, + topology_cache=axpy_cache_schur) + del WVi_Wt + matvec_fn = lambda p, _S=schur_op, _y=schur_Ap_buf: \ + sparse_bsr_mv(_S, p, y=_y, beta=0.0) + if schur_M is None: + schur_M = BlockJacobi(schur_op) + else: + schur_M.refresh(schur_op) + rhs_c.copy_(Ic) + sparse_bsr_mv(V_i, Ip, y=scratch_pts2, beta=0.0) + sparse_bsr_mv(W, scratch_pts2, y=rhs_c, alpha=-1.0, beta=1.0) - - wp.copy(src=Ip, dest=rhs_p) - + D_c.zero_() + cg(matvec_fn, rhs_c, x=D_c, M=schur_M, + tol=solver_tol, maxiter=solver_maxiter, + r_buf=cg_r_buf_c, p_buf=cg_p_buf_c) + + rhs_p.copy_(Ip) if self.matrix_free_normal: - sparse.bsr_mv(J0wp, D_c, y=scratch_obs, beta=0.0) - sparse.bsr_mv(J1wp, scratch_obs, y=rhs_p, - alpha=-1.0, beta=1.0, transpose=True) + sparse_bsr_mv(J0, D_c, y=scratch_obs, beta=0.0) + sparse_bsr_mv(J1, scratch_obs, y=rhs_p, alpha=-1.0, beta=1.0, transpose=True) else: - sparse.bsr_mv(Wt, D_c, y=rhs_p, alpha=-1.0, beta=1.0) - - D_p = wp.zeros_like(rhs_p) - linear.cg( - A=V, - b=rhs_p, - x=D_p, - tol=solver_tol, - maxiter=solver_maxiter, - M=linear.preconditioner(V), - ) + sparse_bsr_mv(Wt, D_c, y=rhs_p, alpha=-1.0, beta=1.0) + + v_M.refresh(V) + D_p.zero_() + cg(lambda p, _V=V, _y=v_Ap_buf: sparse_bsr_mv(_V, p, y=_y, beta=0.0), + rhs_p, x=D_p, M=v_M, + tol=solver_tol, maxiter=solver_maxiter, + r_buf=cg_r_buf_p, p_buf=cg_p_buf_p) - D_c_t = wp.to_torch(D_c).flatten() - D_p_t = wp.to_torch(D_p).flatten() + D_c_t = D_c.flatten() + D_p_t = D_p.flatten() D = torch.cat([D_c_t, D_p_t]) self.update_parameter(pg['params'], D) - self.loss = self.model.loss(input, target) + self.loss = self._chunked_model_loss(input, target) print("Loss:", self.loss, "Last Loss:", self.last, "Reject Count:", self.reject_count, "Damping:", pg['damping']) self.strategy.update( @@ -216,7 +250,7 @@ def schur_matvec(x, y, z, alpha, beta, _V_i=V_i): last=self.last, loss=self.loss, J=J, - Jwp=[J0wp, J1wp], + Jwp=[J0, J1], D=[D_c_t, D_p_t], R=R_flat.view(-1, 1), ) diff --git a/bae/optim/triton_kernel.py b/bae/optim/triton_kernel.py new file mode 100644 index 0000000..70d0ae8 --- /dev/null +++ b/bae/optim/triton_kernel.py @@ -0,0 +1,718 @@ +from typing import Optional, Tuple +import torch +import triton +import triton.language as tl + + +def _bsr_to_torch( + A, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Tuple[int, int], Tuple[int, int], torch.dtype, torch.device]: + + if isinstance(A, torch.Tensor) and A.layout == torch.sparse_bsr: + crow = A.crow_indices() + col = A.col_indices() + values = A.values() + BR, BC = values.shape[-2], values.shape[-1] + nrow = A.shape[0] // BR + ncol = A.shape[1] // BC + return crow, col, values, (BR, BC), (nrow, ncol), values.dtype, values.device + + raise TypeError(f"Unsupported BSR matrix type: {type(A)}") + + +def _flatten_vec(v: torch.Tensor) -> torch.Tensor: + if not v.is_contiguous(): + raise ValueError("Vector must be contiguous to share memory with Triton kernel") + return v.view(-1) + + +def _build_bsr_output(crow, col, vals, nrow, ncol, BR, BC): + return torch.sparse_bsr_tensor( + crow_indices=crow.to(torch.int32), + col_indices=col.to(torch.int32), + values=vals, + size=(nrow * BR, ncol * BC), + ) + + +_TRITON_DTYPES = { + torch.float16: tl.float16, + torch.float32: tl.float32, + torch.float64: tl.float64, + torch.bfloat16: tl.bfloat16, +} + + +def _triton_dtype(t: torch.dtype): + try: + return _TRITON_DTYPES[t] + except KeyError as exc: + raise TypeError(f"Unsupported dtype for Triton BSR kernels: {t}") from exc + + +def _next_pow2(n: int) -> int: + return 1 << max(0, (int(n) - 1)).bit_length() + + +def _uncompress_rows(crow: torch.Tensor, nnz: int) -> torch.Tensor: + nrow = crow.numel() - 1 + + if nnz == 0: + return torch.empty(0, dtype=torch.int32, device=crow.device) + + counts = (crow[1:] - crow[:-1]).to(torch.int64) + + return torch.repeat_interleave(torch.arange(nrow, device=crow.device, dtype=torch.int32), counts) + + +@triton.jit +def _bsr_mv_fwd_kernel( + A_crow, A_col, A_val, + X, Y, + alpha, beta, + nrow, + BR: tl.constexpr, BC: tl.constexpr, + BLK_R: tl.constexpr, BLK_C: tl.constexpr, + DTYPE: tl.constexpr, +): + + row = tl.program_id(0).to(tl.int32) + if row >= nrow: + return + + r_idx = tl.arange(0, BLK_R) + c_idx = tl.arange(0, BLK_C) + rmask = r_idx < BR + cmask = c_idx < BC + + y_off = row * BR + r_idx + if beta == 0.0: + acc = tl.zeros((BLK_R,), dtype=DTYPE) + else: + prev = tl.load(Y + y_off, mask=rmask, other=0.0).to(DTYPE) + acc = beta * prev + + if alpha != 0.0: + beg = tl.load(A_crow + row).to(tl.int32) + end = tl.load(A_crow + row + 1).to(tl.int32) + block_size = BR * BC + partial = tl.zeros((BLK_R,), dtype=DTYPE) + n = end - beg + for i in tl.range(0, n): + blk = beg + i + col = tl.load(A_col + blk).to(tl.int32) + v_off = blk * block_size + r_idx[:, None] * BC + c_idx[None, :] + block_vals = tl.load( + A_val + v_off, + mask=rmask[:, None] & cmask[None, :], other=0.0, + ).to(DTYPE) + x_off = col * BC + c_idx + x_vals = tl.load(X + x_off, mask=cmask, other=0.0).to(DTYPE) + partial += tl.sum(block_vals * x_vals[None, :], axis=1) + acc += alpha * partial + + tl.store(Y + y_off, acc, mask=rmask) + + +@triton.jit +def _bsr_mv_trans_kernel( + A_crow, A_col, A_val, + X, Y, + alpha, + nrow, + BR: tl.constexpr, BC: tl.constexpr, + BLK_R: tl.constexpr, BLK_C: tl.constexpr, + DTYPE: tl.constexpr, +): + + row = tl.program_id(0).to(tl.int32) + + if row >= nrow: + return + + if alpha == 0.0: + return + + r_idx = tl.arange(0, BLK_R) + c_idx = tl.arange(0, BLK_C) + rmask = r_idx < BR + cmask = c_idx < BC + x_off = row * BR + r_idx + x_vals = tl.load(X + x_off, mask=rmask, other=0.0).to(DTYPE) + beg = tl.load(A_crow + row).to(tl.int32) + end = tl.load(A_crow + row + 1).to(tl.int32) + block_size = BR * BC + n = end - beg + + for i in tl.range(0, n): + blk = beg + i + col = tl.load(A_col + blk).to(tl.int32) + v_off = blk * block_size + r_idx[:, None] * BC + c_idx[None, :] + block_vals = tl.load(A_val + v_off, mask=rmask[:, None] & cmask[None, :], other=0.0).to(DTYPE) + contrib = tl.sum(block_vals * x_vals[:, None], axis=0) * alpha + tl.atomic_add(Y + (col * BC + c_idx), contrib, mask=cmask) + + +@triton.jit +def _gather_transpose_kernel( + src_ptr, sort_idx_ptr, dst_ptr, + nnz, + BR: tl.constexpr, BC: tl.constexpr, + BLK_R: tl.constexpr, BLK_C: tl.constexpr, +): + i = tl.program_id(0).to(tl.int32) + if i >= nnz: + return + src_idx = tl.load(sort_idx_ptr + i).to(tl.int64) + r_idx = tl.arange(0, BLK_R) + c_idx = tl.arange(0, BLK_C) + rmask = r_idx < BR + cmask = c_idx < BC + + src_off = src_idx * (BR * BC) + r_idx[:, None] * BC + c_idx[None, :] + block = tl.load( + src_ptr + src_off, + mask=rmask[:, None] & cmask[None, :], other=0.0, + ) + i64 = i.to(tl.int64) + dst_off = i64 * (BC * BR) + c_idx[None, :] * BR + r_idx[:, None] + tl.store(dst_ptr + dst_off, block, mask=rmask[:, None] & cmask[None, :]) + + +@triton.jit +def _bsr_scale_kernel(Y, beta, n, BLOCK: tl.constexpr, DTYPE: tl.constexpr): + pid = tl.program_id(0) + off = pid * BLOCK + tl.arange(0, BLOCK) + mask = off < n + + if beta == 0.0: + tl.store(Y + off, tl.zeros((BLOCK,), dtype=DTYPE), mask=mask) + else: + v = tl.load(Y + off, mask=mask, other=0.0).to(DTYPE) + tl.store(Y + off, beta * v, mask=mask) + + +@triton.jit +def _bsr_mm_numeric_kernel( + A_crow, A_col, A_val, + B_crow, B_col, B_val, + C_blk_row, C_col, C_val, + alpha, + BR_A: tl.constexpr, BC_A: tl.constexpr, BC_B: tl.constexpr, + BLK_R: tl.constexpr, BLK_K: tl.constexpr, BLK_C: tl.constexpr, + BSEARCH_ITERS: tl.constexpr, + DTYPE: tl.constexpr, +): + + c_blk = tl.program_id(0).to(tl.int32) + c_row = tl.load(C_blk_row + c_blk).to(tl.int32) + c_col = tl.load(C_col + c_blk).to(tl.int32) + r_idx = tl.arange(0, BLK_R) + k_idx = tl.arange(0, BLK_K) + s_idx = tl.arange(0, BLK_C) + rmask = r_idx < BR_A + kmask = k_idx < BC_A + smask = s_idx < BC_B + a_block_size = BR_A * BC_A + b_block_size = BC_A * BC_B + c_block_size = BR_A * BC_B + contrib = tl.zeros((BLK_R, BLK_C), dtype=DTYPE) + + if alpha != 0.0: + a_beg = tl.load(A_crow + c_row).to(tl.int32) + a_end = tl.load(A_crow + c_row + 1).to(tl.int32) + n_a = a_end - a_beg + for i in tl.range(0, n_a): + a_blk = a_beg + i + k = tl.load(A_col + a_blk).to(tl.int32) + b_beg = tl.load(B_crow + k).to(tl.int32) + b_end = tl.load(B_crow + k + 1).to(tl.int32) + lo = b_beg + hi = b_end + + for _ in tl.range(0, BSEARCH_ITERS): + mid = (lo + hi) // 2 + cond = lo < hi + safe_mid = tl.where(cond, mid, b_beg) + v = tl.load(B_col + safe_mid).to(tl.int32) + go_right = cond & (v < c_col) + go_left = cond & (v >= c_col) + lo = tl.where(go_right, mid + 1, lo) + hi = tl.where(go_left, mid, hi) + + in_range = lo < b_end + safe_idx = tl.where(in_range, lo, b_beg) + cur = tl.load(B_col + safe_idx).to(tl.int32) + exists = in_range & (cur == c_col) + safe_b_blk = tl.where(exists, lo, 0) + a_off = a_blk * a_block_size + r_idx[:, None] * BC_A + k_idx[None, :] + a_block = tl.load(A_val + a_off, mask=rmask[:, None] & kmask[None, :], other=0.0).to(DTYPE) + b_off = safe_b_blk * b_block_size + k_idx[:, None] * BC_B + s_idx[None, :] + b_block = tl.load(B_val + b_off,mask=(kmask[:, None] & smask[None, :]) & exists, other=0.0).to(DTYPE) + contrib += tl.sum(a_block[:, :, None] * b_block[None, :, :], axis=1) + + v_off = c_blk * c_block_size + r_idx[:, None] * BC_B + s_idx[None, :] + prev = tl.load(C_val + v_off, mask=rmask[:, None] & smask[None, :], other=0.0).to(DTYPE) + tl.store(C_val + v_off, prev + alpha * contrib, mask=rmask[:, None] & smask[None, :]) + + +def _bsr_mm_topology( + A_crow: torch.Tensor, A_col: torch.Tensor, + B_crow: torch.Tensor, B_col: torch.Tensor, + nrow_C: int, ncol_C: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + + device = A_crow.device + A_crow_l = A_crow.to(torch.int64) + A_col_l = A_col.to(torch.int64) + B_crow_l = B_crow.to(torch.int64) + B_col_l = B_col.to(torch.int64) + a_nnz = A_col_l.numel() + empty = ( + torch.zeros(nrow_C + 1, dtype=torch.int32, device=device), + torch.empty(0, dtype=torch.int32, device=device), + ) + if a_nnz == 0 or B_col_l.numel() == 0: + return empty + + b_lens = B_crow_l[1:] - B_crow_l[:-1] + counts = b_lens[A_col_l] + a_blk_offsets = torch.zeros(a_nnz + 1, dtype=torch.int64, device=device) + a_blk_offsets[1:] = counts.cumsum(0) + total = int(a_blk_offsets[-1].item()) + + if total == 0: + return empty + + a_blk_per_t = torch.repeat_interleave(torch.arange(a_nnz, device=device, dtype=torch.int64), counts) + pos_in_row = ( + torch.arange(total, device=device, dtype=torch.int64) + - a_blk_offsets[a_blk_per_t] + ) + b_starts = B_crow_l[A_col_l] + del B_crow_l + triplet_j = B_col_l[b_starts[a_blk_per_t] + pos_in_row] + del b_starts, pos_in_row, B_col_l + + nrow_A = A_crow_l.numel() - 1 + a_row_per_blk = torch.repeat_interleave( + torch.arange(nrow_A, device=device, dtype=torch.int64), + A_crow_l[1:] - A_crow_l[:-1], + ) + triplet_i = a_row_per_blk[a_blk_per_t] + key = triplet_i * ncol_C + triplet_j + sorted_key, _ = torch.sort(key) + keep = torch.ones(total, dtype=torch.bool, device=device) + keep[1:] = sorted_key[1:] != sorted_key[:-1] + unique_key = sorted_key[keep] + out_i = (unique_key // ncol_C).to(torch.int64) + out_j = (unique_key % ncol_C).to(torch.int32) + counts_per_row = torch.bincount(out_i, minlength=nrow_C) + crow_l = torch.zeros(nrow_C + 1, dtype=torch.int64, device=device) + crow_l[1:] = counts_per_row.cumsum(0) + + return crow_l.to(torch.int32), out_j + + +def sparse_bsr_mv( + A, + x, + y=None, + alpha: float = 1.0, + beta: float = 0.0, + transpose: bool = False, + work_buffer=None, +): + crow, col, vals3d, (BR, BC), (nrow, ncol), dtype, device = _bsr_to_torch(A) + + if transpose: + out_blocks, out_block_size = ncol, BC + in_blocks, in_block_size = nrow, BR + else: + out_blocks, out_block_size = nrow, BR + in_blocks, in_block_size = ncol, BC + + expected_y = out_blocks * out_block_size + vals3d = vals3d.contiguous() + crow_i = crow.to(torch.int32).contiguous() + col_i = col.to(torch.int32).contiguous() + + x_t = _flatten_vec(x) + + if x_t.numel() != in_blocks * in_block_size: + raise ValueError(f"x has {x_t.numel()} scalars, expected {in_blocks * in_block_size}") + + if x_t.dtype != dtype: + raise TypeError(f"x dtype {x_t.dtype} != A dtype {dtype}") + + return_obj = y + + if y is None: + y_t = torch.empty(expected_y, dtype=dtype, device=device) + return_obj = y_t + beta = 0.0 + else: + y_t = _flatten_vec(y) + if y_t.numel() != expected_y: + raise ValueError(f"y has {y_t.numel()} scalars, expected {expected_y}") + if y_t.dtype != dtype: + raise TypeError(f"y dtype {y_t.dtype} != A dtype {dtype}") + + if x_t.data_ptr() == y_t.data_ptr(): + if work_buffer is None: + x_t = y_t.clone() + else: + wb = _flatten_vec(work_buffer) + wb.copy_(y_t) + x_t = wb + + BLK_R = max(_next_pow2(BR), 1) + BLK_C = max(_next_pow2(BC), 1) + DTYPE = _triton_dtype(dtype) + + if transpose: + n = y_t.numel() + SCALE_BLOCK = 1024 + scale_grid = ((n + SCALE_BLOCK - 1) // SCALE_BLOCK,) + _bsr_scale_kernel[scale_grid](y_t, float(beta), n, BLOCK=SCALE_BLOCK, DTYPE=DTYPE) + + if alpha != 0.0: + _bsr_mv_trans_kernel[(nrow,)]( + crow_i, col_i, vals3d, + x_t, y_t, + float(alpha), + nrow, + BR=BR, BC=BC, BLK_R=BLK_R, BLK_C=BLK_C, DTYPE=DTYPE, + ) + else: + _bsr_mv_fwd_kernel[(nrow,)]( + crow_i, col_i, vals3d, + x_t, y_t, + float(alpha), float(beta), + nrow, + BR=BR, BC=BC, BLK_R=BLK_R, BLK_C=BLK_C, DTYPE=DTYPE, + ) + return return_obj + + +def sparse_bsr_mm(A, B, alpha: float = 1.0, *, topology_cache=None): + A_crow, A_col, A_val, (BR_A, BC_A), (nrow_A, ncol_A), dtype, device = _bsr_to_torch(A) + B_crow, B_col, B_val, (BR_B, BC_B), (nrow_B, ncol_B), dtype_b, device_b = _bsr_to_torch(B) + + if dtype != dtype_b: + raise ValueError(f"A and B dtypes differ: {dtype} vs {dtype_b}") + + if device != device_b: + raise ValueError(f"A and B on different devices: {device} vs {device_b}") + + if BC_A != BR_B: + raise ValueError(f"Block-shape mismatch: A.block_shape[1]={BC_A}, B.block_shape[0]={BR_B}") + + if ncol_A != nrow_B: + raise ValueError(f"Block-row/col mismatch: A.ncol={ncol_A}, B.nrow={nrow_B}") + + nrow_C, ncol_C = nrow_A, ncol_B + BR_C, BC_C = BR_A, BC_B + + A_crow_i = A_crow.to(torch.int32).contiguous() + A_col_i = A_col.to(torch.int32).contiguous() + B_crow_i = B_crow.to(torch.int32).contiguous() + B_col_i = B_col.to(torch.int32).contiguous() + + if alpha == 0.0 or A_col.numel() == 0 or B_col.numel() == 0: + crow = torch.zeros(nrow_C + 1, dtype=torch.int32, device=device) + col = torch.empty(0, dtype=torch.int32, device=device) + vals = torch.empty((0, BR_C, BC_C), dtype=dtype, device=device) + return _build_bsr_output(crow, col, vals, nrow_C, ncol_C, BR_C, BC_C) + + cached = topology_cache is not None and "C_crow" in topology_cache + if cached: + C_crow = topology_cache["C_crow"] + C_col = topology_cache["C_col"] + C_blk_row = topology_cache["C_blk_row"] + else: + C_crow, C_col = _bsr_mm_topology(A_crow_i, A_col_i, B_crow_i, B_col_i, nrow_C, ncol_C) + out_nnz_tmp = int(C_col.numel()) + C_blk_row = _uncompress_rows(C_crow, out_nnz_tmp) + if topology_cache is not None: + topology_cache["C_crow"] = C_crow + topology_cache["C_col"] = C_col + topology_cache["C_blk_row"] = C_blk_row + + out_nnz = int(C_col.numel()) + C_vals = torch.zeros((out_nnz, BR_C, BC_C), dtype=dtype, device=device) + if out_nnz == 0: + return _build_bsr_output(C_crow, C_col, C_vals, nrow_C, ncol_C, BR_C, BC_C) + + BLK_R = max(_next_pow2(BR_C), 1) + BLK_K = max(_next_pow2(BC_A), 1) + BLK_C = max(_next_pow2(BC_C), 1) + DTYPE = _triton_dtype(dtype) + BSEARCH_ITERS = max(1, (max(1, ncol_C) - 1).bit_length() + 1) + + A_val_c = A_val.contiguous() + B_val_c = B_val.contiguous() + + _bsr_mm_numeric_kernel[(out_nnz,)]( + A_crow_i, A_col_i, A_val_c, + B_crow_i, B_col_i, B_val_c, + C_blk_row, C_col, C_vals, + float(alpha), + BR_A=BR_A, BC_A=BC_A, BC_B=BC_B, + BLK_R=BLK_R, BLK_K=BLK_K, BLK_C=BLK_C, + BSEARCH_ITERS=BSEARCH_ITERS, + DTYPE=DTYPE, + ) + return _build_bsr_output(C_crow, C_col, C_vals, nrow_C, ncol_C, BR_C, BC_C) + + +def sparse_bsr_transposed(A): + crow, col, vals3d, (BR, BC), (nrow, ncol), dtype, device = _bsr_to_torch(A) + + crow_i = crow.to(torch.int32).contiguous() + col_i = col.to(torch.int32).contiguous() + nnz = int(col_i.numel()) + + if nnz == 0: + new_crow = torch.zeros(ncol + 1, dtype=torch.int32, device=device) + new_col = torch.empty(0, dtype=torch.int32, device=device) + new_vals = torch.empty((0, BC, BR), dtype=dtype, device=device) + return _build_bsr_output(new_crow, new_col, new_vals, ncol, nrow, BC, BR) + + old_row = _uncompress_rows(crow_i, nnz).to(torch.int64) + new_row = col_i.to(torch.int64) + new_col = old_row + + new_nrow, new_ncol = ncol, nrow + key = new_row * new_ncol + new_col + _, sort_idx = torch.sort(key, stable=True) + + sorted_new_row = new_row[sort_idx] + sorted_new_col = new_col[sort_idx].to(torch.int32) + vals3d_c = vals3d.contiguous() + permuted_vals = torch.empty((nnz, BC, BR), dtype=dtype, device=device) + sort_idx64 = sort_idx.to(torch.int64).contiguous() + BLK_R_ = max(_next_pow2(BR), 1) + BLK_C_ = max(_next_pow2(BC), 1) + + _gather_transpose_kernel[(nnz,)]( + vals3d_c, sort_idx64, permuted_vals, + nnz, BR=BR, BC=BC, BLK_R=BLK_R_, BLK_C=BLK_C_, + ) + + counts = torch.bincount(sorted_new_row, minlength=new_nrow) + crow_l = torch.zeros(new_nrow + 1, dtype=torch.int64, device=device) + crow_l[1:] = counts.cumsum(0) + new_crow = crow_l.to(torch.int32) + + return _build_bsr_output(new_crow, sorted_new_col, permuted_vals, new_nrow, new_ncol, BC, BR) + + +def sparse_bsr_axpy(x, y, alpha: float = 1.0, *, topology_cache=None): + x_crow, x_col, x_vals, (BR, BC), (nrow, ncol), dtype, device = _bsr_to_torch(x) + y_crow, y_col, y_vals, (BR_y, BC_y), (nrow_y, ncol_y), dtype_y, device_y = _bsr_to_torch(y) + + if (BR, BC) != (BR_y, BC_y): + raise ValueError(f"Block shapes differ: {(BR, BC)} vs {(BR_y, BC_y)}") + if (nrow, ncol) != (nrow_y, ncol_y): + raise ValueError(f"Block-matrix shapes differ: {(nrow, ncol)} vs {(nrow_y, ncol_y)}") + if dtype != dtype_y: + raise ValueError(f"Dtypes differ: {dtype} vs {dtype_y}") + if device != device_y: + raise ValueError(f"Devices differ: {device} vs {device_y}") + + x_crow_i = x_crow.to(torch.int32).contiguous() + x_col_i = x_col.to(torch.int32).contiguous() + y_crow_i = y_crow.to(torch.int32).contiguous() + y_col_i = y_col.to(torch.int32).contiguous() + x_nnz = int(x_col_i.numel()) + y_nnz = int(y_col_i.numel()) + x_vals_c = x_vals.contiguous() + y_vals_c = y_vals.contiguous() + + if x_nnz == 0: + return _build_bsr_output( + y_crow_i.clone(), y_col_i.clone(), y_vals_c.clone(), + nrow, ncol, BR, BC, + ) + if y_nnz == 0: + return _build_bsr_output( + x_crow_i.clone(), x_col_i.clone(), + (alpha * x_vals_c).contiguous(), nrow, ncol, BR, BC, + ) + + cached = topology_cache is not None and "out_crow" in topology_cache + + if cached: + out_crow = topology_cache["out_crow"] + out_col = topology_cache["out_col"] + x_target = topology_cache["x_target"] + x_pick = topology_cache["x_pick"] + y_target = topology_cache["y_target"] + y_pick = topology_cache["y_pick"] + out_nnz = int(out_col.numel()) + else: + x_block_row = _uncompress_rows(x_crow_i, x_nnz).to(torch.int64) + y_block_row = _uncompress_rows(y_crow_i, y_nnz).to(torch.int64) + + all_row = torch.cat([x_block_row, y_block_row]) + all_col = torch.cat([x_col_i.to(torch.int64), y_col_i.to(torch.int64)]) + all_src = torch.cat([ + torch.zeros(x_nnz, dtype=torch.int8, device=device), + torch.ones(y_nnz, dtype=torch.int8, device=device), + ]) + all_idx = torch.cat([ + torch.arange(x_nnz, dtype=torch.int64, device=device), + torch.arange(y_nnz, dtype=torch.int64, device=device), + ]) + + key = all_row * ncol + all_col + _, sort_idx = torch.sort(key, stable=True) + sorted_key = key[sort_idx] + sorted_src = all_src[sort_idx] + sorted_idx = all_idx[sort_idx] + keep = torch.ones(sorted_key.numel(), dtype=torch.bool, device=device) + keep[1:] = sorted_key[1:] != sorted_key[:-1] + unique_key = sorted_key[keep] + out_nnz = int(unique_key.numel()) + + out_row = (unique_key // ncol).to(torch.int64) + out_col = (unique_key % ncol).to(torch.int32) + output_index = keep.long().cumsum(0) - 1 + + x_mask = sorted_src == 0 + y_mask = ~x_mask + x_pick = sorted_idx[x_mask].contiguous() + x_target = output_index[x_mask].contiguous() + y_pick = sorted_idx[y_mask].contiguous() + y_target = output_index[y_mask].contiguous() + + counts = torch.bincount(out_row, minlength=nrow) + crow_l = torch.zeros(nrow + 1, dtype=torch.int64, device=device) + crow_l[1:] = counts.cumsum(0) + out_crow = crow_l.to(torch.int32) + + if topology_cache is not None: + topology_cache["out_crow"] = out_crow + topology_cache["out_col"] = out_col + topology_cache["x_target"] = x_target + topology_cache["x_pick"] = x_pick + topology_cache["y_target"] = y_target + topology_cache["y_pick"] = y_pick + + out_vals = torch.zeros((out_nnz, BR, BC), dtype=dtype, device=device) + if x_pick.numel() > 0: + out_vals.index_add_(0, x_target, alpha * x_vals_c[x_pick]) + if y_pick.numel() > 0: + out_vals.index_add_(0, y_target, y_vals_c[y_pick]) + + return _build_bsr_output(out_crow, out_col, out_vals, nrow, ncol, BR, BC) + + +class BlockJacobi: + def __init__(self, A, ridge: float = 1e-12): + crow, col, vals3d, (BR, BC), (nrow, _), dtype, device = _bsr_to_torch(A) + if BR != BC: + raise ValueError("Block-Jacobi requires square diagonal blocks") + self.BR = BR + self.nrow = nrow + self.dtype = dtype + self.device = device + self.ridge = ridge + nnz = int(col.numel()) + if nnz > 0: + crow_l = crow.to(torch.int64) + col_l = col.to(torch.int64) + block_row = torch.repeat_interleave( + torch.arange(nrow, device=device, dtype=torch.int64), + crow_l[1:] - crow_l[:-1], + ) + is_diag = block_row == col_l + self._diag_idx = is_diag.nonzero(as_tuple=False).flatten() + self._diag_rows = block_row[self._diag_idx] + else: + self._diag_idx = torch.empty(0, dtype=torch.int64, device=device) + self._diag_rows = torch.empty(0, dtype=torch.int64, device=device) + self._diag_blocks = torch.zeros(nrow, BR, BR, dtype=dtype, device=device) + self._eye = torch.eye(BR, dtype=dtype, device=device) + self._out_buf = torch.empty(nrow, BR, 1, dtype=dtype, device=device) + self.diag_inv = None + self.refresh(A) + + def refresh(self, A): + _, _, vals3d, _, _, _, _ = _bsr_to_torch(A) + self._diag_blocks.zero_() + if self._diag_idx.numel() > 0: + self._diag_blocks[self._diag_rows] = vals3d[self._diag_idx] + self.diag_inv = torch.linalg.inv(self._diag_blocks + self.ridge * self._eye) + + def __call__(self, x_flat: torch.Tensor) -> torch.Tensor: + x = x_flat.reshape(-1, self.BR, 1) + torch.matmul(self.diag_inv, x, out=self._out_buf) + return self._out_buf.view(-1) + + +def cg(matvec, b, x=None, M=None, tol: float = 1e-5, maxiter: Optional[int] = None, *, r_buf=None, p_buf=None): + if x is None: + x = torch.zeros_like(b) + + if maxiter is None or maxiter == 0: + maxiter = b.numel() + + b_flat = b.reshape(-1) + b_norm_sq = torch.dot(b_flat, b_flat).item() + + if b_norm_sq == 0.0: + return x + + atol_sq = (tol ** 2) * b_norm_sq + + Ax = matvec(x) + if r_buf is None: + r = b - Ax + else: + r = r_buf + torch.sub(b, Ax, out=r) + r_flat = r.reshape(-1) + + z = M(r) if M is not None else r + z_flat = z.reshape(-1) + + if p_buf is None: + p = z.clone() + else: + p = p_buf + p.copy_(z) + + rz = torch.dot(r_flat, z_flat) + r_norm_sq = torch.dot(r_flat, r_flat) + + for _ in range(maxiter): + if r_norm_sq.item() <= atol_sq: + break + + Ap = matvec(p) + Ap_flat = Ap.reshape(-1) + alpha = (rz / torch.dot(p.reshape(-1), Ap_flat)).item() + x.add_(p, alpha=alpha) + r.add_(Ap, alpha=-alpha) + + if M is not None: + z = M(r) + z_flat = z.reshape(-1) + + rz_new = torch.dot(r_flat, z_flat) + beta = (rz_new / rz).item() + p.mul_(beta).add_(z) + rz = rz_new + r_norm_sq = torch.dot(r_flat, r_flat) + + return x + + +__all__ = [ + "sparse_bsr_mv", "sparse_bsr_mm", + "sparse_bsr_transposed", "sparse_bsr_axpy", + "BlockJacobi", "cg", +] From b305f816c6ed6b8c45d8f5b09dc71856b4cd4a48 Mon Sep 17 00:00:00 2001 From: SEOKWOOPARK Date: Sat, 23 May 2026 23:52:13 +0000 Subject: [PATCH 19/28] Remove codes relevant to Chunk --- bae/autograd/graph.py | 29 ++--------------------------- bae/optim/optimizer.py | 37 +++++-------------------------------- 2 files changed, 7 insertions(+), 59 deletions(-) diff --git a/bae/autograd/graph.py b/bae/autograd/graph.py index 9f7599f..8408672 100644 --- a/bae/autograd/graph.py +++ b/bae/autograd/graph.py @@ -180,33 +180,8 @@ def backward(output_): warning("No upstream parameters to compute jacobian") return - total_obs = args[0].shape[0] if len(args) > 0 else 0 - max_obs_per_chunk = 50_000 - if total_obs <= max_obs_per_chunk: - with pp.retain_ltype(): - jac_blocks = torch.vmap(jacrev(func, argnums=argnums))(*args) - else: - chunks = [] - with pp.retain_ltype(): - for start in range(0, total_obs, max_obs_per_chunk): - end = min(start + max_obs_per_chunk, total_obs) - chunk_args = [a[start:end] for a in args] - chunks.append(list( - torch.vmap(jacrev(func, argnums=argnums))(*chunk_args) - )) - n_args = len(argnums) - jac_blocks = [] - - for i in range(n_args): - parts = [c[i] for c in chunks] - jac_blocks.append(torch.cat(parts, dim=0)) - del parts - - for c in chunks: - c[i] = None - - jac_blocks = tuple(jac_blocks) - del chunks + with pp.retain_ltype(): + jac_blocks = torch.vmap(jacrev(func, argnums=argnums))(*args) for jacidx, argidx in enumerate(argnums): jac_block = jac_blocks[jacidx] arg = args[argidx] diff --git a/bae/optim/optimizer.py b/bae/optim/optimizer.py index f4a5f4d..895061d 100644 --- a/bae/optim/optimizer.py +++ b/bae/optim/optimizer.py @@ -16,38 +16,11 @@ class LM(ppLM): - def __init__(self, *args, matrix_free_normal: bool = False, loss_chunk_size: int = 100_000, **kwargs): + def __init__(self, *args, matrix_free_normal: bool = False, **kwargs): self.matrix_free_normal = matrix_free_normal - self.loss_chunk_size = loss_chunk_size super(LM, self).__init__(*args, **kwargs) self.mm = CuSparse() - def _chunked_model_loss(self, input, target=None): - m = self.model - if not isinstance(input, dict): - return m.loss(input, target) - obs_axis_keys = ("points_2d", "camera_indices", "point_indices") - if not all(k in input for k in obs_axis_keys): - return m.loss(input, target) - n = input["points_2d"].shape[0] - chunk = self.loss_chunk_size - if chunk <= 0 or n <= chunk: - return m.loss(input, target) - - total = None - for start in range(0, n, chunk): - end = min(start + chunk, n) - chunk_input = {k: input[k][start:end] for k in obs_axis_keys} - output = m.model_forward(chunk_input) - chunk_residuals = m.residuals(output, target) - if len(m.kernel) > 1: - parts = [k(r.square().sum(-1)).sum() for k, r in zip(m.kernel, chunk_residuals)] - else: - parts = [m.kernel[0](r.square().sum(-1)).sum() for r in chunk_residuals] - chunk_loss = sum(parts) - total = chunk_loss if total is None else total + chunk_loss - return total - @torch.no_grad() def step(self, input, target=None, weight=None): for pg in self.param_groups: @@ -61,7 +34,7 @@ def step(self, input, target=None, weight=None): J = torch.cat([j.to_sparse_coo() for j in J_list], dim=-1).to_sparse_csr() del J_list - self.last = self.loss = self.loss if hasattr(self, 'loss') else self._chunked_model_loss(input, target) + self.last = self.loss = self.loss if hasattr(self, 'loss') else self.model.loss(input, target) self.reject_count = 0 if self.matrix_free_normal: @@ -88,7 +61,7 @@ def step(self, input, target=None, weight=None): print(e, "\nLinear solver failed. Breaking optimization step...") break self.update_parameter(pg['params'], D) - self.loss = self._chunked_model_loss(input, target) + self.loss = self.model.loss(input, target) print("Loss:", self.loss, "Last Loss:", self.last, "Reject Count:", self.reject_count, "Damping:", pg['damping']) self.strategy.update(pg, last=self.last, loss=self.loss, J=J, D=D, R=R.view(-1, 1)) if self.last < self.loss and self.reject_count < self.reject: # reject step @@ -130,7 +103,7 @@ def step(self, input, target=None, weight=None): R = R.detach() torch.cuda.empty_cache() - self.last = self.loss = self.loss if hasattr(self, 'loss') else self._chunked_model_loss(input, target) + self.last = self.loss = self.loss if hasattr(self, 'loss') else self.model.loss(input, target) J0 = J[0] J1 = J[1] @@ -242,7 +215,7 @@ def schur_matvec(p, _V_i=V_i, _z=schur_Ap_buf): D_p_t = D_p.flatten() D = torch.cat([D_c_t, D_p_t]) self.update_parameter(pg['params'], D) - self.loss = self._chunked_model_loss(input, target) + self.loss = self.model.loss(input, target) print("Loss:", self.loss, "Last Loss:", self.last, "Reject Count:", self.reject_count, "Damping:", pg['damping']) self.strategy.update( From a0b4b8b8d4f688aa71d0bf69711132c59907a1af Mon Sep 17 00:00:00 2001 From: SEOKWOOPARK Date: Sun, 24 May 2026 00:49:01 +0000 Subject: [PATCH 20/28] Remove ba_helpers.py --- ba_helpers.py | 46 ---------------------------------------------- 1 file changed, 46 deletions(-) delete mode 100644 ba_helpers.py diff --git a/ba_helpers.py b/ba_helpers.py deleted file mode 100644 index b7f5014..0000000 --- a/ba_helpers.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch -import torch.nn as nn -from bae.autograd.function import TrackingTensor, map_transform -from bae.utils.ba import rotate_euler, rotate_quat - -USE_QUATERNIONS = True - -@map_transform -def project(points, camera_params): - """Convert 3-D points to 2-D by projecting onto images.""" - if USE_QUATERNIONS: - points_proj = rotate_quat(points, camera_params[..., :7]) - else: - points_proj = rotate_euler(points, camera_params[..., 3:6]) - points_proj = points_proj + camera_params[..., :3] - points_proj = -points_proj[..., :2] / points_proj[..., 2].unsqueeze(-1) - f = camera_params[..., -3].unsqueeze(-1) - k1 = camera_params[..., -2].unsqueeze(-1) - k2 = camera_params[..., -1].unsqueeze(-1) - - n = torch.sum(points_proj**2, axis=-1, keepdim=True) - r = 1 + k1 * n + k2 * n**2 - points_proj = points_proj * r * f - - return points_proj - -class Reproj(nn.Module): - def __init__(self, camera_params, points_3d): - super().__init__() - self.pose = nn.Parameter(TrackingTensor(camera_params)) - self.points_3d = nn.Parameter(TrackingTensor(points_3d)) - self.pose.trim_SE3_grad = True - - def forward(self, points_2d, camera_indices, point_indices): - camera_params = self.pose - points_3d = self.points_3d - - points_proj = project(points_3d[point_indices], camera_params[camera_indices]) - loss = points_proj - points_2d - return loss - -def least_square_error(camera_params, points_3d, camera_indices, point_indices, points_2d): - model = Reproj(camera_params, points_3d) - loss = model(points_2d, camera_indices, point_indices) - return torch.sum(loss**2, dim=-1).mean() - return torch.sum(loss**2) / 2 \ No newline at end of file From f46fb7487010fb63a3bf1701777e9f521be1a69c Mon Sep 17 00:00:00 2001 From: SEOKWOOPARK Date: Sun, 24 May 2026 02:35:48 +0000 Subject: [PATCH 21/28] Fix a conflict in ba_example.py --- ba_example.py | 137 ++++++++++++++++++-------------------------------- 1 file changed, 48 insertions(+), 89 deletions(-) diff --git a/ba_example.py b/ba_example.py index 9ea8fe7..25789b3 100644 --- a/ba_example.py +++ b/ba_example.py @@ -1,23 +1,15 @@ from time import perf_counter from pathlib import Path from datetime import datetime -import torch -import pypose as pp -import warp as wp -from ba_helpers import Reproj, least_square_error -from bae.optim.optimizer import Schur -from bae.optim.triton_kernel import sparse_bsr_mv -from datapipes.bal_loader import get_problem, read_bal_data -from bae.sparse.py_ops import * -from bae.optim import LM -from bae.utils.pysolvers import PCG, CuDSS -import torch.nn as nn from pypose.autograd.function import psjac - -from bae.autograd.function import TrackingTensor, map_transform from datapipes.bal_loader import get_problem -from bae.optim import LM +from bae.optim.optimizer import Schur +from bae.optim.triton_kernel import sparse_bsr_mv from bae.utils.pysolvers import PCG +import pypose as pp +import torch +import torch.nn as nn +import warp as wp TARGET_DATASET = "trafalgar" TARGET_PROBLEM = "problem-257-65132-pre" @@ -33,8 +25,6 @@ DEVICE = "cuda" OPTIMIZE_INTRINSICS = True NUM_CAMERA_PARAMS = 10 if OPTIMIZE_INTRINSICS else 7 - -USE_QUATERNIONS = True REPORT_WARP_MEMPOOL = True @@ -42,15 +32,12 @@ def _format_bytes(num_bytes: int) -> str: sign = "-" if num_bytes < 0 else "" size = float(abs(num_bytes)) units = ["B", "KiB", "MiB", "GiB", "TiB"] - for unit in units: if size < 1024.0 or unit == units[-1]: break size /= 1024.0 - if unit == "B": return f"{sign}{int(size)} {unit}" - return f"{sign}{size:.2f} {unit}" @@ -80,6 +67,12 @@ def forward(self, observes, cidx, pidx): return points_proj - observes +def least_square_error(camera_params, points, cidx, pidx, observes): + model = Residual(camera_params, points) + loss = model(observes, cidx, pidx) + return torch.sum(loss**2, dim=-1).mean() + + class TrustRegion(pp.optim.strategy.TrustRegion): def update(self, pg, last, loss, J, D, R, *args, **kwargs): Jwp = kwargs.get("Jwp") @@ -87,23 +80,22 @@ def update(self, pg, last, loss, J, D, R, *args, **kwargs): J = Jwp JD = None - for i in range(len(D)): if Jwp is not None: JD_i = sparse_bsr_mv(J[i], D[i].flatten().contiguous()).flatten() else: JD_i = J[i] @ D[i].flatten() JD = JD_i if JD is None else JD + JD_i - + JD = JD[..., None] denom = -((JD).mT @ (2 * R.view_as(JD) + JD)).squeeze() - + if loss >= last or denom <= 0: quality = -1.0 else: quality = (last - loss) / denom - - pg['radius'] = 1. / pg['damping'] + + pg['radius'] = 1.0 / pg['damping'] if quality > pg['high']: pg['radius'] = pg['up'] * pg['radius'] pg['down'] = self.down @@ -115,53 +107,32 @@ def update(self, pg, last, loss, J, D, R, *args, **kwargs): pg['down'] = pg['down'] * pg['factor'] pg['down'] = max(self.min, min(pg['down'], self.max)) pg['radius'] = max(self.min, min(pg['radius'], self.max)) - pg['damping'] = 1. / pg['radius'] - - -class Adaptive(pp.optim.strategy.Adaptive): - def update(self, pg, last, loss, J, D, R, *args, **kwargs): - J = [i.to_sparse_coo() for i in J] - JD = None - for i in range(len(D)): - if JD is None: - JD = J[i] @ D[i] - else: - JD += J[i] @ D[i] - JD = JD[..., None] - quality = (last - loss) / -((JD).mT @ (2 * R.view_as(JD) + JD)).squeeze() - if quality > pg['high']: - pg['damping'] = pg['damping'] * pg['down'] - elif quality > pg['low']: - pg['damping'] = pg['damping'] - else: - pg['damping'] = pg['damping'] * pg['up'] - pg['damping'] = max(self.min, min(pg['damping'], self.max)) + pg['damping'] = 1.0 / pg['radius'] def main(): - file_name = f'{TARGET_DATASET}.{TARGET_PROBLEM}' + file_name = f"{TARGET_DATASET}.{TARGET_PROBLEM}" cuda_device = torch.device(DEVICE) if DEVICE.startswith("cuda") else None memory_snapshot_path = None warp_device = None warp_mempool_start_current = None warp_mempool_start_high = None - total_memory = None - nontorch_baseline = None - dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET, use_quat=USE_QUATERNIONS) + dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET) + print(f"Fetched {TARGET_PROBLEM} from {TARGET_DATASET}") + dataset = { key: value.to(DEVICE) for key, value in dataset.items() if isinstance(value, torch.Tensor) } - input = { - "points_2d": dataset["points_2d"], - "camera_indices": dataset["camera_index_of_observations"], - "point_indices": dataset["point_index_of_observations"], + "observes": dataset["points_2d"], + "cidx": dataset["camera_index_of_observations"], + "pidx": dataset["point_index_of_observations"], } - if DEVICE.startswith("cuda") and torch.cuda.is_available(): + if cuda_device is not None and torch.cuda.is_available(): timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") snapshot_dir = Path("memory_traces") snapshot_dir.mkdir(exist_ok=True) @@ -173,8 +144,6 @@ def main(): device=cuda_device, clear_history=True, ) - else: - print("CUDA is not available; skipping CUDA memory tracking.") if REPORT_WARP_MEMPOOL and DEVICE.startswith("cuda"): try: @@ -187,41 +156,37 @@ def main(): except Exception as e: print(f"Warning: failed to query Warp mempool stats: {e}") - model = Reproj( - dataset['camera_params'][:, :NUM_CAMERA_PARAMS].clone(), - dataset['points_3d'].clone() + model = Residual( + dataset["camera_params"][:, :NUM_CAMERA_PARAMS].clone(), + dataset["points_3d"].clone(), ).to(DEVICE) - strategy = TrustRegion(up=2.0, down=0.5**4) solver = PCG(tol=1e-4, maxiter=250) optimizer = Schur(model, strategy=strategy, solver=solver, reject=30, matrix_free_normal=True) - print('Initial loss:', least_square_error( + print('Loss:', least_square_error( model.pose, - model.points_3d, - dataset['camera_index_of_observations'], - dataset['point_index_of_observations'], - dataset['points_2d'], + model.points, + dataset["camera_index_of_observations"], + dataset["point_index_of_observations"], + dataset["points_2d"], ).item()) + print("Initial loss", optimizer.model.loss(input, None).item()) + if cuda_device is not None and torch.cuda.is_available(): torch.cuda.synchronize(cuda_device) torch.cuda.reset_peak_memory_stats(cuda_device) - free_baseline, total_memory = torch.cuda.mem_get_info(cuda_device) - torch_reserved_baseline = torch.cuda.memory_reserved(cuda_device) - warp_current_baseline = (wp.get_mempool_used_mem_current(warp_device) if warp_device is not None else 0) - nontorch_baseline = ((total_memory - free_baseline) - torch_reserved_baseline - warp_current_baseline) start = perf_counter() for idx in range(20): loss = optimizer.step(input) - print('Iteration', idx, 'loss', loss.item(), 'time', perf_counter() - start) + print("Iteration", idx, "loss", loss.item(), "time", perf_counter() - start) if cuda_device is not None and torch.cuda.is_available(): torch.cuda.synchronize(cuda_device) end = perf_counter() - - print('Time', end - start) + print("Time", end - start) if memory_snapshot_path: torch.cuda.synchronize(cuda_device) @@ -230,39 +195,33 @@ def main(): if cuda_device is not None and torch.cuda.is_available(): peak_allocated = torch.cuda.max_memory_allocated(cuda_device) - try: peak_reserved = torch.cuda.max_memory_reserved(cuda_device) except AttributeError: peak_reserved = torch.cuda.max_memory_cached(cuda_device) - - free_end, _ = torch.cuda.mem_get_info(cuda_device) - torch_reserved_end = torch.cuda.memory_reserved(cuda_device) - warp_current_end = (wp.get_mempool_used_mem_current(warp_device) if warp_device is not None else 0) - nontorch_end = ((total_memory - free_end) - torch_reserved_end - warp_current_end) - module_growth = nontorch_end - nontorch_baseline if nontorch_baseline is not None else None - print(f"Peak CUDA memory allocated: {_format_bytes(peak_allocated)}") print(f"Peak CUDA memory reserved: {_format_bytes(peak_reserved)}") - print(f"Non-allocator CUDA memory (context + kernel modules): {_format_bytes(nontorch_end)}") - if module_growth is not None: - print(f"Kernel-module growth during run (Triton JIT binaries): {_format_bytes(module_growth)}") if warp_device is not None and warp_mempool_start_current is not None: try: warp_current = wp.get_mempool_used_mem_current(warp_device) warp_high = wp.get_mempool_used_mem_high(warp_device) - print(f"Warp CUDA mempool current: {_format_bytes(warp_current)} (Δ {_format_bytes(warp_current - warp_mempool_start_current)})") - print(f"Warp CUDA mempool high-water: {_format_bytes(warp_high)} (Δ {_format_bytes(warp_high - warp_mempool_start_high)})") + print(f"Warp CUDA mempool current: {_format_bytes(warp_current)} " + f"(Δ {_format_bytes(warp_current - warp_mempool_start_current)})" + ) + print( + f"Warp CUDA mempool high-water: {_format_bytes(warp_high)} " + f"(Δ {_format_bytes(warp_high - warp_mempool_start_high)})" + ) except Exception as e: print(f"Warning: failed to query Warp mempool stats: {e}") print('Ending loss:', least_square_error( model.pose, - model.points_3d, - dataset['camera_index_of_observations'], - dataset['point_index_of_observations'], - dataset['points_2d'], + model.points, + dataset["camera_index_of_observations"], + dataset["point_index_of_observations"], + dataset["points_2d"], ).item()) From 48ad787242f0a246ba791c9dd6945501bc872a84 Mon Sep 17 00:00:00 2001 From: Zitong Zhan Date: Sat, 23 May 2026 20:35:47 -0700 Subject: [PATCH 22/28] Potential fix for pull request finding 'Variable defined multiple times' Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> --- bae/sparse/warp_wrappers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bae/sparse/warp_wrappers.py b/bae/sparse/warp_wrappers.py index 153c48e..dd3f4bb 100644 --- a/bae/sparse/warp_wrappers.py +++ b/bae/sparse/warp_wrappers.py @@ -9,7 +9,6 @@ def torchbsr2wp(tbsr): assert tbsr.layout == torch.sparse_bsr - block_type = wp.types.matrix(shape=tbsr.values().shape[-2:], dtype=wp.dtype_from_torch(tbsr.dtype)) block_type = wp.types.matrix(shape=tuple(tbsr.values().shape[-2:]), dtype=wp.dtype_from_torch(tbsr.dtype)) bsr = wps.bsr_matrix_t(block_type)() bsr.nrow = int(tbsr.shape[0] // block_type._shape_[0]) From 8cc6eb3233aa8395e9497671f961ac52f244243c Mon Sep 17 00:00:00 2001 From: Zitong Zhan Date: Sat, 23 May 2026 21:45:52 -0700 Subject: [PATCH 23/28] Potential fix for pull request finding 'Unused local variable' Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> --- bae/optim/optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bae/optim/optimizer.py b/bae/optim/optimizer.py index 895061d..3a15e1d 100644 --- a/bae/optim/optimizer.py +++ b/bae/optim/optimizer.py @@ -143,7 +143,6 @@ def step(self, input, target=None, weight=None): if self.matrix_free_normal: scratch_obs = torch.empty_like(R_flat) scratch_pts = torch.empty_like(Ip) - z_buf = torch.empty_like(Ic) solver_tol = getattr(self.solver, "tol", None) or 1e-5 solver_maxiter = getattr(self.solver, "maxiter", 0) or 0 From 074b931e8a90e5abff38e0d2a2eff5d5fc9ced53 Mon Sep 17 00:00:00 2001 From: Zitong Zhan Date: Sun, 24 May 2026 04:46:59 +0000 Subject: [PATCH 24/28] minimize diff --- bae/autograd/graph.py | 2 +- bae/sparse/py_ops.py | 13 ------------- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/bae/autograd/graph.py b/bae/autograd/graph.py index 4e16d7a..cd7de28 100644 --- a/bae/autograd/graph.py +++ b/bae/autograd/graph.py @@ -1,3 +1,4 @@ + from typing import Optional import warnings @@ -185,7 +186,6 @@ def backward(output_, is_root=False): if len(argnums) == 0: warnings.warn("No upstream parameters to compute jacobian", stacklevel=2) return - with pp.retain_ltype(): jac_blocks = torch.vmap(jacrev(func, argnums=argnums))(*args) for jacidx, argidx in enumerate(argnums): diff --git a/bae/sparse/py_ops.py b/bae/sparse/py_ops.py index 6f7ba0b..058daee 100644 --- a/bae/sparse/py_ops.py +++ b/bae/sparse/py_ops.py @@ -236,16 +236,3 @@ def bsr2bsc(J): sparse_lib = Library('aten', 'IMPL') sparse_lib.impl('diagonal', diagonal_op_, 'SparseCsrCPU') sparse_lib.impl('diagonal', diagonal_op_, 'SparseCsrCUDA') - -if __name__ == "__main__": - if torch.cuda.is_available(): - crow_indices = torch.tensor([0, 2, 4]) - col_indices = torch.tensor([0, 1, 0, 1]) - values = torch.tensor([[[0, 1, 2], [6, 7, 8]], - [[3, 4, 5], [9, 10, 11]], - [[12, 13, 14], [18, 19, 20]], - [[15, 16, 17], [21, 22, 23]]]) - bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, dtype=torch.float64) - bsr = bsr.to('cuda') - csr = bsr.to_sparse_coo().to_sparse_csr() - output = diagonal_op_triton_(csr) From 4746522a69af273afaabe306d4af1f94bcdd22be Mon Sep 17 00:00:00 2001 From: Zitong Zhan Date: Sun, 24 May 2026 05:04:25 +0000 Subject: [PATCH 25/28] restore pysolvers --- bae/utils/pysolvers.py | 232 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 224 insertions(+), 8 deletions(-) diff --git a/bae/utils/pysolvers.py b/bae/utils/pysolvers.py index 25458d0..0a61e3e 100644 --- a/bae/utils/pysolvers.py +++ b/bae/utils/pysolvers.py @@ -26,16 +26,9 @@ def forward(self, A, b, x=None, M=None) -> torch.Tensor: was_vector = b.dim() == 1 if was_vector: b = b[..., None] - l_diag = A.diagonal().clone() + l_diag = A.diagonal() l_diag[l_diag.abs() < 1e-6] = 1e-6 M = spdiags_((1 / l_diag), None, shape=A.shape, layout=None) - layout = getattr(A, "layout", torch.strided) - if layout == torch.sparse_csr: - # M = M.to_sparse_csr() - pass - # A = M @ A - # elif layout == torch.sparse_bsr and isinstance(A, torch.Tensor): - if A.layout == torch.sparse_bsr: M = M.to_sparse_bsr(blocksize=A.values().shape[-2:]).to(A.device) @@ -65,3 +58,226 @@ def forward(self, A, b): # print(f"Linear Solver Error: {a_err}, relative error: {r_err}") return torch.from_numpy(x).to(A.device) + +# cuda graph version of the solver +class CG_(torch.nn.Module): + r'''The batched linear solver with conjugate gradient method. + + .. math:: + \mathbf{A}_i \bm{x}_i = \mathbf{b}_i, + + where :math:`\mathbf{A}_i \in \mathbb{C}^{M \times N}` and :math:`\bm{b}_i \in + \mathbb{C}^{M \times 1}` are the :math:`i`-th item of batched linear equations. + + This function is a 1:1 replica of `scipy.sparse.linalg.cg `_. + The solution is consistent with the scipy version up to numerical precision. + Variable names are kept the same as the scipy version for easy reference. + We recommend using only non-batched or batch size 1 input for this solver, as + the batched version was not appeared in the original scipy version. When handling + sparse matrices, the batched computation may introduce additional overhead. + + Examples: + >>> # dense example + >>> import pypose.optim.solver as ppos + >>> A = torch.tensor([[0.1802967, 0.3151198, 0.4548111, 0.3860016, 0.2870615], + [0.3151198, 1.4575327, 1.5533425, 1.0540756, 1.0795838], + [0.4548111, 1.5533425, 2.3674474, 1.1222278, 1.2365348], + [0.3860016, 1.0540756, 1.1222278, 1.3748058, 1.2223261], + [0.2870615, 1.0795838, 1.2365348, 1.2223261, 1.2577004]]) + >>> b = torch.tensor([[ 2.64306851], + [-0.03593633], + [ 0.73612658], + [ 0.51501254], + [-0.26689271]]) + >>> solver = ppos.CG() + >>> x = solver(A, b) + tensor([[246.4098], + [ 22.6997], + [-56.9239], + [-161.7914], + [137.2683]]) + + >>> # sparse csr example + >>> import pypose.optim.solver as ppos + >>> crow_indices = torch.tensor([0, 2, 4]) + >>> col_indices = torch.tensor([0, 1, 0, 1]) + >>> values = torch.tensor([1, 2, 3, 4], dtype=torch.float) + >>> A = torch.sparse_csr_tensor(crow_indices, col_indices, values) + >>> A.to_dense() # visualize + tensor([[1., 2.], + [3., 4.]]) + >>> b = torch.tensor([[1.], [2.]]) + >>> solver = ppos.CG() + >>> x = solver(A, b) + tensor([-4.4052e-05, 5.0003e-01]) + + ''' + def __init__(self, maxiter=None, tol=1e-5): + super().__init__() + self.maxiter, self.tol = maxiter, tol + self.graph_first_iter = None + self.graph_subsequent_iter = None + self.static_A_shape, self.static_b_shape, self.static_M_is_none, self.static_device = \ + None, None, None, None + # Tensors for graph capture/replay + self.static_A, self.static_b, self.static_M = None, None, None + self.static_x, self.static_r, self.static_p, self.static_q, self.static_z = \ + None, None, None, None, None + self.static_rho_prev, self.static_rho_cur = None, None + + def forward(self, A: torch.Tensor, b: Tensor, x: Optional[Tensor]=None, + M: Optional[torch.Tensor]=None) -> Tensor: + ''' + Args: + A (Tensor): the input tensor. It is assumed to be a symmetric + positive-definite matrix. Layout is allowed to be COO, CSR, BSR, or dense. + b (Tensor): the tensor on the right hand side. Layout could be sparse or dense + but is only allowed to be a type that is compatible with the layout of A. + In other words, `A @ b` operation must be supported by the layout of A. + x (Tensor, optional): the initial guess for the solution. Default: ``None``. + M (Tensor, optional): the preconditioner for A. Layout is allowed to be COO, + CSR, BSR, or dense. Default: ``None``. + + Return: + Tensor: the solved tensor. Layout is the same as the layout of b. + ''' + if A.ndim == b.ndim + 1: + b = b.unsqueeze(-1) + else: + assert A.ndim == b.ndim, \ + 'The number of dimensions of A and b must be the same or one more than b' + + if x is None: + x = torch.zeros_like(b) + + bnrm2 = torch.linalg.norm(b, dim=0) + if (bnrm2 == 0).all(): + return b + atol = self.tol * bnrm2 + n = b.shape[-2] + + if self.maxiter is None: + maxiter = n * 10 + else: + maxiter = self.maxiter + + # Determine if CUDA graph can be used and if re-capture is needed + use_cuda_graph = A.is_cuda + + if use_cuda_graph: + re_capture_graph = (self.graph_first_iter is None or \ + self.static_A_shape != A.shape or \ + self.static_b_shape != b.shape or \ + self.static_M_is_none != (M is None) or \ + self.static_device != A.device) + + if re_capture_graph: + # Allocate static tensors and capture new graphs + self.static_A = A.clone() + self.static_b = b.clone() + self.static_x = x.clone() # Initial x + self.static_r = b - A @ x # Initial r + self.static_p = torch.zeros_like(b) # Will be updated + self.static_q = torch.empty_like(b) + self.static_z = torch.empty_like(b) + + # Initialize rho_prev and rho_cur with shape [1, 1] + self.static_rho_prev = torch.zeros(1, 1, device=A.device) + self.static_rho_cur = torch.zeros(1, 1, device=A.device) + + self.static_M_is_none = (M is None) + self.static_device = A.device + self.static_A_shape = A.shape + self.static_b_shape = b.shape + + if M is not None: + self.static_M = M.clone() + else: + self.static_M = None + + # Capture first iteration graph + self.graph_first_iter = torch.cuda.CUDAGraph() + torch.cuda.synchronize() + with torch.cuda.graph(self.graph_first_iter): + # Operations for first iteration + if not self.static_M_is_none: + torch.matmul(self.static_M, self.static_r, out=self.static_z) + else: + self.static_z.copy_(self.static_r) # z = r.clone() + self.static_rho_cur.copy_(torch.matmul(self.static_r.mT, self.static_z)) + self.static_p.copy_(self.static_z) # p = z.clone() + torch.matmul(self.static_A, self.static_p, out=self.static_q) + alpha = self.static_rho_cur / torch.matmul(self.static_p.mT, self.static_q) + self.static_x.add_(alpha * self.static_p) + self.static_r.sub_(alpha * self.static_q) + self.static_rho_prev.copy_(self.static_rho_cur) + + # Capture subsequent iteration graph + self.graph_subsequent_iter = torch.cuda.CUDAGraph() + torch.cuda.synchronize() + with torch.cuda.graph(self.graph_subsequent_iter): + # Operations for subsequent iterations + if not self.static_M_is_none: + torch.matmul(self.static_M, self.static_r, out=self.static_z) + else: + self.static_z.copy_(self.static_r) # z = r.clone() + self.static_rho_cur.copy_(torch.matmul(self.static_r.mT, self.static_z)) + beta = self.static_rho_cur / self.static_rho_prev + self.static_p.mul_(beta).add_(self.static_z) + torch.matmul(self.static_A, self.static_p, out=self.static_q) + alpha = self.static_rho_cur / torch.matmul(self.static_p.mT, self.static_q) + self.static_x.add_(alpha * self.static_p) + self.static_r.sub_(alpha * self.static_q) + self.static_rho_prev.copy_(self.static_rho_cur) + + # Now run the loop using the (newly captured or existing) graphs + self.static_A.copy_(A) + self.static_b.copy_(b) + self.static_x.copy_(x) + self.static_r.copy_(b - A @ x) # Initial r + if M is not None: + self.static_M.copy_(M) + + # First iteration + self.graph_first_iter.replay() + if (torch.linalg.norm(self.static_r, dim=0) < atol).all(): + return self.static_x + + # Subsequent iterations + for iteration in range(1, maxiter): + self.graph_subsequent_iter.replay() + if (torch.linalg.norm(self.static_r, dim=0) < atol).all(): + return self.static_x + return self.static_x + + else: # A is not on CUDA, or other conditions not met for graph, run original Python loop + r = b - A @ x if x.any() else b.clone() + rho_prev, p = None, None + + q = torch.empty_like(b) + if M is not None: + z = torch.empty_like(b) + else: + z = r.clone() + + for iteration in range(maxiter): + if (torch.linalg.norm(r, dim=0) < atol).all(): + return x + + if M is not None: + torch.matmul(M, r, out=z) + rho_cur = torch.matmul(r.mT, z) + if iteration > 0: + beta = rho_cur / rho_prev + p.mul_(beta).add_(z) + else: # First spin + p = z.clone() + + torch.matmul(A, p, out=q) + alpha = rho_cur / torch.matmul(p.mT, q) + x += alpha * p + r -= alpha * q + rho_prev = rho_cur + + return x From 7f3ea3d3bf0ad92bea1a970bc8181c74e462025c Mon Sep 17 00:00:00 2001 From: Zitong Zhan Date: Sun, 24 May 2026 05:06:13 +0000 Subject: [PATCH 26/28] revert import shuffle --- bae/optim/optimizer.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/bae/optim/optimizer.py b/bae/optim/optimizer.py index 3a15e1d..f018c69 100644 --- a/bae/optim/optimizer.py +++ b/bae/optim/optimizer.py @@ -1,12 +1,8 @@ -import torch -import pypose as pp from functools import partial +import torch from pypose.optim import LevenbergMarquardt as ppLM -from .triton_kernel import ( - sparse_bsr_mm, sparse_bsr_mv, - sparse_bsr_transposed, sparse_bsr_axpy, - BlockJacobi, cg, -) +import pypose as pp + from ..autograd.graph import jacobian from ..autograd.function import TrackingTensor from ..sparse.py_ops import diagonal_op_, inv_op @@ -14,6 +10,12 @@ from ..utils.linear_operator import NormalMatVec from ..utils.parameter import parameter_update_shape +from .triton_kernel import ( + sparse_bsr_mm, sparse_bsr_mv, + sparse_bsr_transposed, sparse_bsr_axpy, + BlockJacobi, cg, +) + class LM(ppLM): def __init__(self, *args, matrix_free_normal: bool = False, **kwargs): @@ -234,4 +236,3 @@ def schur_matvec(p, _V_i=V_i, _z=schur_Ap_buf): break return self.loss - From d3e24d9605950c1531d30f91ed55863eecd99368 Mon Sep 17 00:00:00 2001 From: Zitong Zhan Date: Sun, 24 May 2026 05:08:28 +0000 Subject: [PATCH 27/28] restore LM --- bae/optim/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bae/optim/optimizer.py b/bae/optim/optimizer.py index f018c69..2e335d9 100644 --- a/bae/optim/optimizer.py +++ b/bae/optim/optimizer.py @@ -87,7 +87,7 @@ def update_parameter(self, params, step): if param.shape[-1] > 7: param[:, 7:] += step_view[..., 6:] else: - param.add_(d.view(param.shape)) + param.add_(step_view) class Schur(LM): From 04908d93ec750448ee551f47be507829ef6c92fb Mon Sep 17 00:00:00 2001 From: Zitong Zhan Date: Sun, 24 May 2026 05:12:54 +0000 Subject: [PATCH 28/28] fix import order ba example --- ba_example.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/ba_example.py b/ba_example.py index 25789b3..8d0e41a 100644 --- a/ba_example.py +++ b/ba_example.py @@ -1,15 +1,17 @@ from time import perf_counter -from pathlib import Path from datetime import datetime +from pathlib import Path + +import pypose as pp +import torch +import torch.nn as nn +import warp as wp from pypose.autograd.function import psjac + from datapipes.bal_loader import get_problem from bae.optim.optimizer import Schur from bae.optim.triton_kernel import sparse_bsr_mv from bae.utils.pysolvers import PCG -import pypose as pp -import torch -import torch.nn as nn -import warp as wp TARGET_DATASET = "trafalgar" TARGET_PROBLEM = "problem-257-65132-pre" @@ -226,4 +228,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main()