diff --git a/.gitignore b/.gitignore index 0f6aced..749df84 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ tmp_debug/* task.md *.pickle *.png +.warp_cache/* .DS_Store tmp/* examples/module/pgo/data/* diff --git a/ba_example.py b/ba_example.py index 28ff359..8d0e41a 100644 --- a/ba_example.py +++ b/ba_example.py @@ -1,25 +1,46 @@ from time import perf_counter +from datetime import datetime +from pathlib import Path import pypose as pp import torch import torch.nn as nn +import warp as wp from pypose.autograd.function import psjac from datapipes.bal_loader import get_problem -from bae.optim import LM +from bae.optim.optimizer import Schur +from bae.optim.triton_kernel import sparse_bsr_mv from bae.utils.pysolvers import PCG TARGET_DATASET = "trafalgar" TARGET_PROBLEM = "problem-257-65132-pre" -# other options: # TARGET_DATASET = "ladybug" # TARGET_PROBLEM = "problem-1723-156502-pre" # TARGET_DATASET = "dubrovnik" # TARGET_PROBLEM = "problem-356-226730-pre" +# TARGET_DATASET = "final" +# TARGET_PROBLEM = "problem-13682-4456117-pre" +# TARGET_DATASET = "venice" +# TARGET_PROBLEM = "problem-1778-993923-pre" DEVICE = "cuda" OPTIMIZE_INTRINSICS = True NUM_CAMERA_PARAMS = 10 if OPTIMIZE_INTRINSICS else 7 +REPORT_WARP_MEMPOOL = True + + +def _format_bytes(num_bytes: int) -> str: + sign = "-" if num_bytes < 0 else "" + size = float(abs(num_bytes)) + units = ["B", "KiB", "MiB", "GiB", "TiB"] + for unit in units: + if size < 1024.0 or unit == units[-1]: + break + size /= 1024.0 + if unit == "B": + return f"{sign}{int(size)} {unit}" + return f"{sign}{size:.2f} {unit}" @psjac @@ -54,7 +75,51 @@ def least_square_error(camera_params, points, cidx, pidx, observes): return torch.sum(loss**2, dim=-1).mean() +class TrustRegion(pp.optim.strategy.TrustRegion): + def update(self, pg, last, loss, J, D, R, *args, **kwargs): + Jwp = kwargs.get("Jwp") + if Jwp is not None: + J = Jwp + + JD = None + for i in range(len(D)): + if Jwp is not None: + JD_i = sparse_bsr_mv(J[i], 
D[i].flatten().contiguous()).flatten() + else: + JD_i = J[i] @ D[i].flatten() + JD = JD_i if JD is None else JD + JD_i + + JD = JD[..., None] + denom = -((JD).mT @ (2 * R.view_as(JD) + JD)).squeeze() + + if loss >= last or denom <= 0: + quality = -1.0 + else: + quality = (last - loss) / denom + + pg['radius'] = 1.0 / pg['damping'] + if quality > pg['high']: + pg['radius'] = pg['up'] * pg['radius'] + pg['down'] = self.down + elif quality > pg['low']: + pg['radius'] = pg['radius'] + pg['down'] = self.down + else: + pg['radius'] = pg['radius'] * pg['down'] + pg['down'] = pg['down'] * pg['factor'] + pg['down'] = max(self.min, min(pg['down'], self.max)) + pg['radius'] = max(self.min, min(pg['radius'], self.max)) + pg['damping'] = 1.0 / pg['radius'] + + def main(): + file_name = f"{TARGET_DATASET}.{TARGET_PROBLEM}" + cuda_device = torch.device(DEVICE) if DEVICE.startswith("cuda") else None + memory_snapshot_path = None + warp_device = None + warp_mempool_start_current = None + warp_mempool_start_high = None + dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET) print(f"Fetched {TARGET_PROBLEM} from {TARGET_DATASET}") @@ -69,13 +134,37 @@ def main(): "pidx": dataset["point_index_of_observations"], } + if cuda_device is not None and torch.cuda.is_available(): + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + snapshot_dir = Path("memory_traces") + snapshot_dir.mkdir(exist_ok=True) + memory_snapshot_path = snapshot_dir / f"{file_name}_cuda_memory_{timestamp}.pickle" + torch.cuda.memory._record_memory_history( + enabled="all", + context="all", + stacks="python", + device=cuda_device, + clear_history=True, + ) + + if REPORT_WARP_MEMPOOL and DEVICE.startswith("cuda"): + try: + if wp.is_cuda_available(): + warp_device = wp.get_device("cuda:0" if DEVICE == "cuda" else DEVICE) + if not wp.is_mempool_enabled(warp_device): + wp.set_mempool_enabled(warp_device, True) + warp_mempool_start_current = wp.get_mempool_used_mem_current(warp_device) + warp_mempool_start_high = 
wp.get_mempool_used_mem_high(warp_device) + except Exception as e: + print(f"Warning: failed to query Warp mempool stats: {e}") + model = Residual( dataset["camera_params"][:, :NUM_CAMERA_PARAMS].clone(), dataset["points_3d"].clone(), ).to(DEVICE) - strategy = pp.optim.strategy.TrustRegion(up=2.0, down=0.5**4) + strategy = TrustRegion(up=2.0, down=0.5**4) solver = PCG(tol=1e-4, maxiter=250) - optimizer = LM(model, strategy=strategy, solver=solver, reject=30) + optimizer = Schur(model, strategy=strategy, solver=solver, reject=30, matrix_free_normal=True) print('Loss:', least_square_error( model.pose, @@ -87,15 +176,48 @@ def main(): print("Initial loss", optimizer.model.loss(input, None).item()) + if cuda_device is not None and torch.cuda.is_available(): + torch.cuda.synchronize(cuda_device) + torch.cuda.reset_peak_memory_stats(cuda_device) + start = perf_counter() for idx in range(20): loss = optimizer.step(input) print("Iteration", idx, "loss", loss.item(), "time", perf_counter() - start) - torch.cuda.synchronize() + if cuda_device is not None and torch.cuda.is_available(): + torch.cuda.synchronize(cuda_device) end = perf_counter() print("Time", end - start) + if memory_snapshot_path: + torch.cuda.synchronize(cuda_device) + torch.cuda.memory._dump_snapshot(str(memory_snapshot_path)) + print(f"CUDA memory snapshot saved to {memory_snapshot_path}") + + if cuda_device is not None and torch.cuda.is_available(): + peak_allocated = torch.cuda.max_memory_allocated(cuda_device) + try: + peak_reserved = torch.cuda.max_memory_reserved(cuda_device) + except AttributeError: + peak_reserved = torch.cuda.max_memory_cached(cuda_device) + print(f"Peak CUDA memory allocated: {_format_bytes(peak_allocated)}") + print(f"Peak CUDA memory reserved: {_format_bytes(peak_reserved)}") + + if warp_device is not None and warp_mempool_start_current is not None: + try: + warp_current = wp.get_mempool_used_mem_current(warp_device) + warp_high = wp.get_mempool_used_mem_high(warp_device) + 
print(f"Warp CUDA mempool current: {_format_bytes(warp_current)} " + f"(Δ {_format_bytes(warp_current - warp_mempool_start_current)})" + ) + print( + f"Warp CUDA mempool high-water: {_format_bytes(warp_high)} " + f"(Δ {_format_bytes(warp_high - warp_mempool_start_high)})" + ) + except Exception as e: + print(f"Warning: failed to query Warp mempool stats: {e}") + print('Ending loss:', least_square_error( model.pose, model.points, diff --git a/bae/optim/optimizer.py b/bae/optim/optimizer.py index 1fcb4a5..2e335d9 100644 --- a/bae/optim/optimizer.py +++ b/bae/optim/optimizer.py @@ -2,17 +2,24 @@ import torch from pypose.optim import LevenbergMarquardt as ppLM import pypose as pp + from ..autograd.graph import jacobian from ..autograd.function import TrackingTensor -from ..sparse.py_ops import diagonal_op_ +from ..sparse.py_ops import diagonal_op_, inv_op from ..sparse.spgemm import CuSparse +from ..utils.linear_operator import NormalMatVec from ..utils.parameter import parameter_update_shape - +from .triton_kernel import ( + sparse_bsr_mm, sparse_bsr_mv, + sparse_bsr_transposed, sparse_bsr_axpy, + BlockJacobi, cg, +) class LM(ppLM): - def __init__(self, *args, **kwargs): + def __init__(self, *args, matrix_free_normal: bool = False, **kwargs): + self.matrix_free_normal = matrix_free_normal super(LM, self).__init__(*args, **kwargs) self.mm = CuSparse() @@ -20,26 +27,38 @@ def __init__(self, *args, **kwargs): def step(self, input, target=None, weight=None): for pg in self.param_groups: weight = self.weight if weight is None else weight - R = list(self.model(input)) - R = R[0] - J = jacobian(R, pg['params']) + R = self.model(input)[0] + J_list = jacobian(R, pg['params']) + if isinstance(R, TrackingTensor): R = R.tensor() - J = torch.cat([j.to_sparse_coo() for j in J], dim=-1) + + J = torch.cat([j.to_sparse_coo() for j in J_list], dim=-1).to_sparse_csr() + del J_list self.last = self.loss = self.loss if hasattr(self, 'loss') else self.model.loss(input, target) - J_T = J.mT 
self.reject_count = 0 - J_T = J_T.to_sparse_csr() - J = J.to_sparse_csr() - A = self.mm(J_T, J) - diagonal_op_(A, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) + if self.matrix_free_normal: + diag = NormalMatVec._compute_diag(J).clamp(min=pg['min'], max=pg['max']) + A = NormalMatVec(J, damping=0.0, diag=diag) + rhs = -(A._get_Jt() @ R.view(-1, 1)) + diag_scale = 1.0 + else: + J_T = J.mT.to_sparse_csr() + rhs = -J_T @ R.view(-1, 1) + A = self.mm(J_T, J) + del J_T + diagonal_op_(A, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) while self.last <= self.loss: - diagonal_op_(A, op=partial(torch.mul, other=1+pg['damping'])) + if self.matrix_free_normal: + diag_scale *= 1.0 + pg['damping'] + A.set_damping(diag_scale - 1.0) + else: + diagonal_op_(A, op=partial(torch.mul, other=1+pg['damping'])) try: - D = self.solver(A, -J_T @ R.view(-1, 1)) + D = self.solver(A, rhs) except Exception as e: print(e, "\nLinear solver failed. Breaking optimization step...") break @@ -69,3 +88,151 @@ def update_parameter(self, params, step): param[:, 7:] += step_view[..., 6:] else: param.add_(step_view) + + +class Schur(LM): + @torch.no_grad() + def step(self, input, target=None, weight=None): + for pg in self.param_groups: + self.reject_count = 0 + weight = self.weight if weight is None else weight + R = self.model(input, target)[0] + J = jacobian(R, pg['params']) + + if isinstance(R, TrackingTensor): + R = R.tensor() + else: + R = R.detach() + torch.cuda.empty_cache() + + self.last = self.loss = self.loss if hasattr(self, 'loss') else self.model.loss(input, target) + + J0 = J[0] + J1 = J[1] + if self.matrix_free_normal: + J0t = sparse_bsr_transposed(J0) + U = sparse_bsr_mm(J0t, J0) + del J0t + J1t = sparse_bsr_transposed(J1) + V = sparse_bsr_mm(J1t, J1) + del J1t + else: + J0t = sparse_bsr_transposed(J0) + J1t = sparse_bsr_transposed(J1) + U = sparse_bsr_mm(J0t, J0) + V = sparse_bsr_mm(J1t, J1) + W = sparse_bsr_mm(J0t, J1) + Wt = sparse_bsr_transposed(W) + del J0t, J1t + + 
diagonal_op_(U, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) + diagonal_op_(V, op=partial(torch.clamp_, min=pg['min'], max=pg['max'])) + R_flat = R.reshape(-1).contiguous() + Ic = sparse_bsr_mv(J0, R_flat, alpha=-1.0, transpose=True) + Ip = sparse_bsr_mv(J1, R_flat, alpha=-1.0, transpose=True) + rhs_c = torch.empty_like(Ic) + rhs_p = torch.empty_like(Ip) + scratch_pts2 = torch.empty_like(Ip) + schur_Ap_buf = torch.empty_like(Ic) + v_Ap_buf = torch.empty_like(Ip) + D_c = torch.empty_like(Ic) + D_p = torch.empty_like(Ip) + cg_r_buf_c = torch.empty_like(Ic) + cg_p_buf_c = torch.empty_like(Ic) + cg_r_buf_p = torch.empty_like(Ip) + cg_p_buf_p = torch.empty_like(Ip) + + if self.matrix_free_normal: + scratch_obs = torch.empty_like(R_flat) + scratch_pts = torch.empty_like(Ip) + + solver_tol = getattr(self.solver, "tol", None) or 1e-5 + solver_maxiter = getattr(self.solver, "maxiter", 0) or 0 + mm_cache_WV_i = {} if not self.matrix_free_normal else None + mm_cache_WVi_Wt = {} if not self.matrix_free_normal else None + axpy_cache_schur = {} if not self.matrix_free_normal else None + v_M = BlockJacobi(V) + schur_M = BlockJacobi(U) if self.matrix_free_normal else None + + while self.last <= self.loss: + damp = partial(torch.mul, other=1+pg['damping']) + diagonal_op_(U, op=damp) + diagonal_op_(V, op=damp) + V_i = inv_op(V) + + if self.matrix_free_normal: + def schur_matvec(p, _V_i=V_i, _z=schur_Ap_buf): + sparse_bsr_mv(J0, p, y=scratch_obs, beta=0.0) + sparse_bsr_mv(J1, scratch_obs, y=scratch_pts, beta=0.0, transpose=True) + sparse_bsr_mv(_V_i, scratch_pts, y=scratch_pts2, beta=0.0) + sparse_bsr_mv(J1, scratch_pts2, y=scratch_obs, beta=0.0) + sparse_bsr_mv(J0, scratch_obs, y=_z, alpha=-1.0, beta=0.0, transpose=True) + sparse_bsr_mv(U, p, y=_z, alpha=1.0, beta=1.0) + return _z + + matvec_fn = schur_matvec + schur_M.refresh(U) + rhs_c.copy_(Ic) + sparse_bsr_mv(V_i, Ip, y=scratch_pts2, beta=0.0) + sparse_bsr_mv(J1, scratch_pts2, y=scratch_obs, beta=0.0) + 
sparse_bsr_mv(J0, scratch_obs, y=rhs_c, alpha=-1.0, beta=1.0, transpose=True) + else: + WV_i = sparse_bsr_mm(W, V_i, topology_cache=mm_cache_WV_i) + WVi_Wt = sparse_bsr_mm(WV_i, Wt, topology_cache=mm_cache_WVi_Wt) + del WV_i + schur_op = sparse_bsr_axpy(WVi_Wt, U, alpha=-1.0, + topology_cache=axpy_cache_schur) + del WVi_Wt + matvec_fn = lambda p, _S=schur_op, _y=schur_Ap_buf: \ + sparse_bsr_mv(_S, p, y=_y, beta=0.0) + if schur_M is None: + schur_M = BlockJacobi(schur_op) + else: + schur_M.refresh(schur_op) + rhs_c.copy_(Ic) + sparse_bsr_mv(V_i, Ip, y=scratch_pts2, beta=0.0) + sparse_bsr_mv(W, scratch_pts2, y=rhs_c, alpha=-1.0, beta=1.0) + + D_c.zero_() + cg(matvec_fn, rhs_c, x=D_c, M=schur_M, + tol=solver_tol, maxiter=solver_maxiter, + r_buf=cg_r_buf_c, p_buf=cg_p_buf_c) + + rhs_p.copy_(Ip) + if self.matrix_free_normal: + sparse_bsr_mv(J0, D_c, y=scratch_obs, beta=0.0) + sparse_bsr_mv(J1, scratch_obs, y=rhs_p, alpha=-1.0, beta=1.0, transpose=True) + else: + sparse_bsr_mv(Wt, D_c, y=rhs_p, alpha=-1.0, beta=1.0) + + v_M.refresh(V) + D_p.zero_() + cg(lambda p, _V=V, _y=v_Ap_buf: sparse_bsr_mv(_V, p, y=_y, beta=0.0), + rhs_p, x=D_p, M=v_M, + tol=solver_tol, maxiter=solver_maxiter, + r_buf=cg_r_buf_p, p_buf=cg_p_buf_p) + + D_c_t = D_c.flatten() + D_p_t = D_p.flatten() + D = torch.cat([D_c_t, D_p_t]) + self.update_parameter(pg['params'], D) + self.loss = self.model.loss(input, target) + print("Loss:", self.loss, "Last Loss:", self.last, "Reject Count:", self.reject_count, "Damping:", pg['damping']) + + self.strategy.update( + pg, + last=self.last, + loss=self.loss, + J=J, + Jwp=[J0, J1], + D=[D_c_t, D_p_t], + R=R_flat.view(-1, 1), + ) + + if self.last < self.loss and self.reject_count < self.reject: # reject step + self.update_parameter(params=pg['params'], step=-D) + self.loss, self.reject_count = self.last, self.reject_count + 1 + else: + break + + return self.loss diff --git a/bae/optim/triton_kernel.py b/bae/optim/triton_kernel.py new file mode 100644 index 
0000000..70d0ae8 --- /dev/null +++ b/bae/optim/triton_kernel.py @@ -0,0 +1,718 @@ +from typing import Optional, Tuple +import torch +import triton +import triton.language as tl + + +def _bsr_to_torch( + A, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Tuple[int, int], Tuple[int, int], torch.dtype, torch.device]: + + if isinstance(A, torch.Tensor) and A.layout == torch.sparse_bsr: + crow = A.crow_indices() + col = A.col_indices() + values = A.values() + BR, BC = values.shape[-2], values.shape[-1] + nrow = A.shape[0] // BR + ncol = A.shape[1] // BC + return crow, col, values, (BR, BC), (nrow, ncol), values.dtype, values.device + + raise TypeError(f"Unsupported BSR matrix type: {type(A)}") + + +def _flatten_vec(v: torch.Tensor) -> torch.Tensor: + if not v.is_contiguous(): + raise ValueError("Vector must be contiguous to share memory with Triton kernel") + return v.view(-1) + + +def _build_bsr_output(crow, col, vals, nrow, ncol, BR, BC): + return torch.sparse_bsr_tensor( + crow_indices=crow.to(torch.int32), + col_indices=col.to(torch.int32), + values=vals, + size=(nrow * BR, ncol * BC), + ) + + +_TRITON_DTYPES = { + torch.float16: tl.float16, + torch.float32: tl.float32, + torch.float64: tl.float64, + torch.bfloat16: tl.bfloat16, +} + + +def _triton_dtype(t: torch.dtype): + try: + return _TRITON_DTYPES[t] + except KeyError as exc: + raise TypeError(f"Unsupported dtype for Triton BSR kernels: {t}") from exc + + +def _next_pow2(n: int) -> int: + return 1 << max(0, (int(n) - 1)).bit_length() + + +def _uncompress_rows(crow: torch.Tensor, nnz: int) -> torch.Tensor: + nrow = crow.numel() - 1 + + if nnz == 0: + return torch.empty(0, dtype=torch.int32, device=crow.device) + + counts = (crow[1:] - crow[:-1]).to(torch.int64) + + return torch.repeat_interleave(torch.arange(nrow, device=crow.device, dtype=torch.int32), counts) + + +@triton.jit +def _bsr_mv_fwd_kernel( + A_crow, A_col, A_val, + X, Y, + alpha, beta, + nrow, + BR: tl.constexpr, BC: tl.constexpr, + BLK_R: 
tl.constexpr, BLK_C: tl.constexpr, + DTYPE: tl.constexpr, +): + + row = tl.program_id(0).to(tl.int32) + if row >= nrow: + return + + r_idx = tl.arange(0, BLK_R) + c_idx = tl.arange(0, BLK_C) + rmask = r_idx < BR + cmask = c_idx < BC + + y_off = row * BR + r_idx + if beta == 0.0: + acc = tl.zeros((BLK_R,), dtype=DTYPE) + else: + prev = tl.load(Y + y_off, mask=rmask, other=0.0).to(DTYPE) + acc = beta * prev + + if alpha != 0.0: + beg = tl.load(A_crow + row).to(tl.int32) + end = tl.load(A_crow + row + 1).to(tl.int32) + block_size = BR * BC + partial = tl.zeros((BLK_R,), dtype=DTYPE) + n = end - beg + for i in tl.range(0, n): + blk = beg + i + col = tl.load(A_col + blk).to(tl.int32) + v_off = blk * block_size + r_idx[:, None] * BC + c_idx[None, :] + block_vals = tl.load( + A_val + v_off, + mask=rmask[:, None] & cmask[None, :], other=0.0, + ).to(DTYPE) + x_off = col * BC + c_idx + x_vals = tl.load(X + x_off, mask=cmask, other=0.0).to(DTYPE) + partial += tl.sum(block_vals * x_vals[None, :], axis=1) + acc += alpha * partial + + tl.store(Y + y_off, acc, mask=rmask) + + +@triton.jit +def _bsr_mv_trans_kernel( + A_crow, A_col, A_val, + X, Y, + alpha, + nrow, + BR: tl.constexpr, BC: tl.constexpr, + BLK_R: tl.constexpr, BLK_C: tl.constexpr, + DTYPE: tl.constexpr, +): + + row = tl.program_id(0).to(tl.int32) + + if row >= nrow: + return + + if alpha == 0.0: + return + + r_idx = tl.arange(0, BLK_R) + c_idx = tl.arange(0, BLK_C) + rmask = r_idx < BR + cmask = c_idx < BC + x_off = row * BR + r_idx + x_vals = tl.load(X + x_off, mask=rmask, other=0.0).to(DTYPE) + beg = tl.load(A_crow + row).to(tl.int32) + end = tl.load(A_crow + row + 1).to(tl.int32) + block_size = BR * BC + n = end - beg + + for i in tl.range(0, n): + blk = beg + i + col = tl.load(A_col + blk).to(tl.int32) + v_off = blk * block_size + r_idx[:, None] * BC + c_idx[None, :] + block_vals = tl.load(A_val + v_off, mask=rmask[:, None] & cmask[None, :], other=0.0).to(DTYPE) + contrib = tl.sum(block_vals * x_vals[:, None], 
axis=0) * alpha + tl.atomic_add(Y + (col * BC + c_idx), contrib, mask=cmask) + + +@triton.jit +def _gather_transpose_kernel( + src_ptr, sort_idx_ptr, dst_ptr, + nnz, + BR: tl.constexpr, BC: tl.constexpr, + BLK_R: tl.constexpr, BLK_C: tl.constexpr, +): + i = tl.program_id(0).to(tl.int32) + if i >= nnz: + return + src_idx = tl.load(sort_idx_ptr + i).to(tl.int64) + r_idx = tl.arange(0, BLK_R) + c_idx = tl.arange(0, BLK_C) + rmask = r_idx < BR + cmask = c_idx < BC + + src_off = src_idx * (BR * BC) + r_idx[:, None] * BC + c_idx[None, :] + block = tl.load( + src_ptr + src_off, + mask=rmask[:, None] & cmask[None, :], other=0.0, + ) + i64 = i.to(tl.int64) + dst_off = i64 * (BC * BR) + c_idx[None, :] * BR + r_idx[:, None] + tl.store(dst_ptr + dst_off, block, mask=rmask[:, None] & cmask[None, :]) + + +@triton.jit +def _bsr_scale_kernel(Y, beta, n, BLOCK: tl.constexpr, DTYPE: tl.constexpr): + pid = tl.program_id(0) + off = pid * BLOCK + tl.arange(0, BLOCK) + mask = off < n + + if beta == 0.0: + tl.store(Y + off, tl.zeros((BLOCK,), dtype=DTYPE), mask=mask) + else: + v = tl.load(Y + off, mask=mask, other=0.0).to(DTYPE) + tl.store(Y + off, beta * v, mask=mask) + + +@triton.jit +def _bsr_mm_numeric_kernel( + A_crow, A_col, A_val, + B_crow, B_col, B_val, + C_blk_row, C_col, C_val, + alpha, + BR_A: tl.constexpr, BC_A: tl.constexpr, BC_B: tl.constexpr, + BLK_R: tl.constexpr, BLK_K: tl.constexpr, BLK_C: tl.constexpr, + BSEARCH_ITERS: tl.constexpr, + DTYPE: tl.constexpr, +): + + c_blk = tl.program_id(0).to(tl.int32) + c_row = tl.load(C_blk_row + c_blk).to(tl.int32) + c_col = tl.load(C_col + c_blk).to(tl.int32) + r_idx = tl.arange(0, BLK_R) + k_idx = tl.arange(0, BLK_K) + s_idx = tl.arange(0, BLK_C) + rmask = r_idx < BR_A + kmask = k_idx < BC_A + smask = s_idx < BC_B + a_block_size = BR_A * BC_A + b_block_size = BC_A * BC_B + c_block_size = BR_A * BC_B + contrib = tl.zeros((BLK_R, BLK_C), dtype=DTYPE) + + if alpha != 0.0: + a_beg = tl.load(A_crow + c_row).to(tl.int32) + a_end = 
tl.load(A_crow + c_row + 1).to(tl.int32) + n_a = a_end - a_beg + for i in tl.range(0, n_a): + a_blk = a_beg + i + k = tl.load(A_col + a_blk).to(tl.int32) + b_beg = tl.load(B_crow + k).to(tl.int32) + b_end = tl.load(B_crow + k + 1).to(tl.int32) + lo = b_beg + hi = b_end + + for _ in tl.range(0, BSEARCH_ITERS): + mid = (lo + hi) // 2 + cond = lo < hi + safe_mid = tl.where(cond, mid, b_beg) + v = tl.load(B_col + safe_mid).to(tl.int32) + go_right = cond & (v < c_col) + go_left = cond & (v >= c_col) + lo = tl.where(go_right, mid + 1, lo) + hi = tl.where(go_left, mid, hi) + + in_range = lo < b_end + safe_idx = tl.where(in_range, lo, b_beg) + cur = tl.load(B_col + safe_idx).to(tl.int32) + exists = in_range & (cur == c_col) + safe_b_blk = tl.where(exists, lo, 0) + a_off = a_blk * a_block_size + r_idx[:, None] * BC_A + k_idx[None, :] + a_block = tl.load(A_val + a_off, mask=rmask[:, None] & kmask[None, :], other=0.0).to(DTYPE) + b_off = safe_b_blk * b_block_size + k_idx[:, None] * BC_B + s_idx[None, :] + b_block = tl.load(B_val + b_off,mask=(kmask[:, None] & smask[None, :]) & exists, other=0.0).to(DTYPE) + contrib += tl.sum(a_block[:, :, None] * b_block[None, :, :], axis=1) + + v_off = c_blk * c_block_size + r_idx[:, None] * BC_B + s_idx[None, :] + prev = tl.load(C_val + v_off, mask=rmask[:, None] & smask[None, :], other=0.0).to(DTYPE) + tl.store(C_val + v_off, prev + alpha * contrib, mask=rmask[:, None] & smask[None, :]) + + +def _bsr_mm_topology( + A_crow: torch.Tensor, A_col: torch.Tensor, + B_crow: torch.Tensor, B_col: torch.Tensor, + nrow_C: int, ncol_C: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + + device = A_crow.device + A_crow_l = A_crow.to(torch.int64) + A_col_l = A_col.to(torch.int64) + B_crow_l = B_crow.to(torch.int64) + B_col_l = B_col.to(torch.int64) + a_nnz = A_col_l.numel() + empty = ( + torch.zeros(nrow_C + 1, dtype=torch.int32, device=device), + torch.empty(0, dtype=torch.int32, device=device), + ) + if a_nnz == 0 or B_col_l.numel() == 0: + return empty 
+ + b_lens = B_crow_l[1:] - B_crow_l[:-1] + counts = b_lens[A_col_l] + a_blk_offsets = torch.zeros(a_nnz + 1, dtype=torch.int64, device=device) + a_blk_offsets[1:] = counts.cumsum(0) + total = int(a_blk_offsets[-1].item()) + + if total == 0: + return empty + + a_blk_per_t = torch.repeat_interleave(torch.arange(a_nnz, device=device, dtype=torch.int64), counts) + pos_in_row = ( + torch.arange(total, device=device, dtype=torch.int64) + - a_blk_offsets[a_blk_per_t] + ) + b_starts = B_crow_l[A_col_l] + del B_crow_l + triplet_j = B_col_l[b_starts[a_blk_per_t] + pos_in_row] + del b_starts, pos_in_row, B_col_l + + nrow_A = A_crow_l.numel() - 1 + a_row_per_blk = torch.repeat_interleave( + torch.arange(nrow_A, device=device, dtype=torch.int64), + A_crow_l[1:] - A_crow_l[:-1], + ) + triplet_i = a_row_per_blk[a_blk_per_t] + key = triplet_i * ncol_C + triplet_j + sorted_key, _ = torch.sort(key) + keep = torch.ones(total, dtype=torch.bool, device=device) + keep[1:] = sorted_key[1:] != sorted_key[:-1] + unique_key = sorted_key[keep] + out_i = (unique_key // ncol_C).to(torch.int64) + out_j = (unique_key % ncol_C).to(torch.int32) + counts_per_row = torch.bincount(out_i, minlength=nrow_C) + crow_l = torch.zeros(nrow_C + 1, dtype=torch.int64, device=device) + crow_l[1:] = counts_per_row.cumsum(0) + + return crow_l.to(torch.int32), out_j + + +def sparse_bsr_mv( + A, + x, + y=None, + alpha: float = 1.0, + beta: float = 0.0, + transpose: bool = False, + work_buffer=None, +): + crow, col, vals3d, (BR, BC), (nrow, ncol), dtype, device = _bsr_to_torch(A) + + if transpose: + out_blocks, out_block_size = ncol, BC + in_blocks, in_block_size = nrow, BR + else: + out_blocks, out_block_size = nrow, BR + in_blocks, in_block_size = ncol, BC + + expected_y = out_blocks * out_block_size + vals3d = vals3d.contiguous() + crow_i = crow.to(torch.int32).contiguous() + col_i = col.to(torch.int32).contiguous() + + x_t = _flatten_vec(x) + + if x_t.numel() != in_blocks * in_block_size: + raise ValueError(f"x 
has {x_t.numel()} scalars, expected {in_blocks * in_block_size}") + + if x_t.dtype != dtype: + raise TypeError(f"x dtype {x_t.dtype} != A dtype {dtype}") + + return_obj = y + + if y is None: + y_t = torch.empty(expected_y, dtype=dtype, device=device) + return_obj = y_t + beta = 0.0 + else: + y_t = _flatten_vec(y) + if y_t.numel() != expected_y: + raise ValueError(f"y has {y_t.numel()} scalars, expected {expected_y}") + if y_t.dtype != dtype: + raise TypeError(f"y dtype {y_t.dtype} != A dtype {dtype}") + + if x_t.data_ptr() == y_t.data_ptr(): + if work_buffer is None: + x_t = y_t.clone() + else: + wb = _flatten_vec(work_buffer) + wb.copy_(y_t) + x_t = wb + + BLK_R = max(_next_pow2(BR), 1) + BLK_C = max(_next_pow2(BC), 1) + DTYPE = _triton_dtype(dtype) + + if transpose: + n = y_t.numel() + SCALE_BLOCK = 1024 + scale_grid = ((n + SCALE_BLOCK - 1) // SCALE_BLOCK,) + _bsr_scale_kernel[scale_grid](y_t, float(beta), n, BLOCK=SCALE_BLOCK, DTYPE=DTYPE) + + if alpha != 0.0: + _bsr_mv_trans_kernel[(nrow,)]( + crow_i, col_i, vals3d, + x_t, y_t, + float(alpha), + nrow, + BR=BR, BC=BC, BLK_R=BLK_R, BLK_C=BLK_C, DTYPE=DTYPE, + ) + else: + _bsr_mv_fwd_kernel[(nrow,)]( + crow_i, col_i, vals3d, + x_t, y_t, + float(alpha), float(beta), + nrow, + BR=BR, BC=BC, BLK_R=BLK_R, BLK_C=BLK_C, DTYPE=DTYPE, + ) + return return_obj + + +def sparse_bsr_mm(A, B, alpha: float = 1.0, *, topology_cache=None): + A_crow, A_col, A_val, (BR_A, BC_A), (nrow_A, ncol_A), dtype, device = _bsr_to_torch(A) + B_crow, B_col, B_val, (BR_B, BC_B), (nrow_B, ncol_B), dtype_b, device_b = _bsr_to_torch(B) + + if dtype != dtype_b: + raise ValueError(f"A and B dtypes differ: {dtype} vs {dtype_b}") + + if device != device_b: + raise ValueError(f"A and B on different devices: {device} vs {device_b}") + + if BC_A != BR_B: + raise ValueError(f"Block-shape mismatch: A.block_shape[1]={BC_A}, B.block_shape[0]={BR_B}") + + if ncol_A != nrow_B: + raise ValueError(f"Block-row/col mismatch: A.ncol={ncol_A}, B.nrow={nrow_B}") + + 
nrow_C, ncol_C = nrow_A, ncol_B + BR_C, BC_C = BR_A, BC_B + + A_crow_i = A_crow.to(torch.int32).contiguous() + A_col_i = A_col.to(torch.int32).contiguous() + B_crow_i = B_crow.to(torch.int32).contiguous() + B_col_i = B_col.to(torch.int32).contiguous() + + if alpha == 0.0 or A_col.numel() == 0 or B_col.numel() == 0: + crow = torch.zeros(nrow_C + 1, dtype=torch.int32, device=device) + col = torch.empty(0, dtype=torch.int32, device=device) + vals = torch.empty((0, BR_C, BC_C), dtype=dtype, device=device) + return _build_bsr_output(crow, col, vals, nrow_C, ncol_C, BR_C, BC_C) + + cached = topology_cache is not None and "C_crow" in topology_cache + if cached: + C_crow = topology_cache["C_crow"] + C_col = topology_cache["C_col"] + C_blk_row = topology_cache["C_blk_row"] + else: + C_crow, C_col = _bsr_mm_topology(A_crow_i, A_col_i, B_crow_i, B_col_i, nrow_C, ncol_C) + out_nnz_tmp = int(C_col.numel()) + C_blk_row = _uncompress_rows(C_crow, out_nnz_tmp) + if topology_cache is not None: + topology_cache["C_crow"] = C_crow + topology_cache["C_col"] = C_col + topology_cache["C_blk_row"] = C_blk_row + + out_nnz = int(C_col.numel()) + C_vals = torch.zeros((out_nnz, BR_C, BC_C), dtype=dtype, device=device) + if out_nnz == 0: + return _build_bsr_output(C_crow, C_col, C_vals, nrow_C, ncol_C, BR_C, BC_C) + + BLK_R = max(_next_pow2(BR_C), 1) + BLK_K = max(_next_pow2(BC_A), 1) + BLK_C = max(_next_pow2(BC_C), 1) + DTYPE = _triton_dtype(dtype) + BSEARCH_ITERS = max(1, (max(1, ncol_C) - 1).bit_length() + 1) + + A_val_c = A_val.contiguous() + B_val_c = B_val.contiguous() + + _bsr_mm_numeric_kernel[(out_nnz,)]( + A_crow_i, A_col_i, A_val_c, + B_crow_i, B_col_i, B_val_c, + C_blk_row, C_col, C_vals, + float(alpha), + BR_A=BR_A, BC_A=BC_A, BC_B=BC_B, + BLK_R=BLK_R, BLK_K=BLK_K, BLK_C=BLK_C, + BSEARCH_ITERS=BSEARCH_ITERS, + DTYPE=DTYPE, + ) + return _build_bsr_output(C_crow, C_col, C_vals, nrow_C, ncol_C, BR_C, BC_C) + + +def sparse_bsr_transposed(A): + crow, col, vals3d, (BR, BC), (nrow, 
ncol), dtype, device = _bsr_to_torch(A) + + crow_i = crow.to(torch.int32).contiguous() + col_i = col.to(torch.int32).contiguous() + nnz = int(col_i.numel()) + + if nnz == 0: + new_crow = torch.zeros(ncol + 1, dtype=torch.int32, device=device) + new_col = torch.empty(0, dtype=torch.int32, device=device) + new_vals = torch.empty((0, BC, BR), dtype=dtype, device=device) + return _build_bsr_output(new_crow, new_col, new_vals, ncol, nrow, BC, BR) + + old_row = _uncompress_rows(crow_i, nnz).to(torch.int64) + new_row = col_i.to(torch.int64) + new_col = old_row + + new_nrow, new_ncol = ncol, nrow + key = new_row * new_ncol + new_col + _, sort_idx = torch.sort(key, stable=True) + + sorted_new_row = new_row[sort_idx] + sorted_new_col = new_col[sort_idx].to(torch.int32) + vals3d_c = vals3d.contiguous() + permuted_vals = torch.empty((nnz, BC, BR), dtype=dtype, device=device) + sort_idx64 = sort_idx.to(torch.int64).contiguous() + BLK_R_ = max(_next_pow2(BR), 1) + BLK_C_ = max(_next_pow2(BC), 1) + + _gather_transpose_kernel[(nnz,)]( + vals3d_c, sort_idx64, permuted_vals, + nnz, BR=BR, BC=BC, BLK_R=BLK_R_, BLK_C=BLK_C_, + ) + + counts = torch.bincount(sorted_new_row, minlength=new_nrow) + crow_l = torch.zeros(new_nrow + 1, dtype=torch.int64, device=device) + crow_l[1:] = counts.cumsum(0) + new_crow = crow_l.to(torch.int32) + + return _build_bsr_output(new_crow, sorted_new_col, permuted_vals, new_nrow, new_ncol, BC, BR) + + +def sparse_bsr_axpy(x, y, alpha: float = 1.0, *, topology_cache=None): + x_crow, x_col, x_vals, (BR, BC), (nrow, ncol), dtype, device = _bsr_to_torch(x) + y_crow, y_col, y_vals, (BR_y, BC_y), (nrow_y, ncol_y), dtype_y, device_y = _bsr_to_torch(y) + + if (BR, BC) != (BR_y, BC_y): + raise ValueError(f"Block shapes differ: {(BR, BC)} vs {(BR_y, BC_y)}") + if (nrow, ncol) != (nrow_y, ncol_y): + raise ValueError(f"Block-matrix shapes differ: {(nrow, ncol)} vs {(nrow_y, ncol_y)}") + if dtype != dtype_y: + raise ValueError(f"Dtypes differ: {dtype} vs {dtype_y}") + if 
device != device_y: + raise ValueError(f"Devices differ: {device} vs {device_y}") + + x_crow_i = x_crow.to(torch.int32).contiguous() + x_col_i = x_col.to(torch.int32).contiguous() + y_crow_i = y_crow.to(torch.int32).contiguous() + y_col_i = y_col.to(torch.int32).contiguous() + x_nnz = int(x_col_i.numel()) + y_nnz = int(y_col_i.numel()) + x_vals_c = x_vals.contiguous() + y_vals_c = y_vals.contiguous() + + if x_nnz == 0: + return _build_bsr_output( + y_crow_i.clone(), y_col_i.clone(), y_vals_c.clone(), + nrow, ncol, BR, BC, + ) + if y_nnz == 0: + return _build_bsr_output( + x_crow_i.clone(), x_col_i.clone(), + (alpha * x_vals_c).contiguous(), nrow, ncol, BR, BC, + ) + + cached = topology_cache is not None and "out_crow" in topology_cache + + if cached: + out_crow = topology_cache["out_crow"] + out_col = topology_cache["out_col"] + x_target = topology_cache["x_target"] + x_pick = topology_cache["x_pick"] + y_target = topology_cache["y_target"] + y_pick = topology_cache["y_pick"] + out_nnz = int(out_col.numel()) + else: + x_block_row = _uncompress_rows(x_crow_i, x_nnz).to(torch.int64) + y_block_row = _uncompress_rows(y_crow_i, y_nnz).to(torch.int64) + + all_row = torch.cat([x_block_row, y_block_row]) + all_col = torch.cat([x_col_i.to(torch.int64), y_col_i.to(torch.int64)]) + all_src = torch.cat([ + torch.zeros(x_nnz, dtype=torch.int8, device=device), + torch.ones(y_nnz, dtype=torch.int8, device=device), + ]) + all_idx = torch.cat([ + torch.arange(x_nnz, dtype=torch.int64, device=device), + torch.arange(y_nnz, dtype=torch.int64, device=device), + ]) + + key = all_row * ncol + all_col + _, sort_idx = torch.sort(key, stable=True) + sorted_key = key[sort_idx] + sorted_src = all_src[sort_idx] + sorted_idx = all_idx[sort_idx] + keep = torch.ones(sorted_key.numel(), dtype=torch.bool, device=device) + keep[1:] = sorted_key[1:] != sorted_key[:-1] + unique_key = sorted_key[keep] + out_nnz = int(unique_key.numel()) + + out_row = (unique_key // ncol).to(torch.int64) + out_col = 
(unique_key % ncol).to(torch.int32) + output_index = keep.long().cumsum(0) - 1 + + x_mask = sorted_src == 0 + y_mask = ~x_mask + x_pick = sorted_idx[x_mask].contiguous() + x_target = output_index[x_mask].contiguous() + y_pick = sorted_idx[y_mask].contiguous() + y_target = output_index[y_mask].contiguous() + + counts = torch.bincount(out_row, minlength=nrow) + crow_l = torch.zeros(nrow + 1, dtype=torch.int64, device=device) + crow_l[1:] = counts.cumsum(0) + out_crow = crow_l.to(torch.int32) + + if topology_cache is not None: + topology_cache["out_crow"] = out_crow + topology_cache["out_col"] = out_col + topology_cache["x_target"] = x_target + topology_cache["x_pick"] = x_pick + topology_cache["y_target"] = y_target + topology_cache["y_pick"] = y_pick + + out_vals = torch.zeros((out_nnz, BR, BC), dtype=dtype, device=device) + if x_pick.numel() > 0: + out_vals.index_add_(0, x_target, alpha * x_vals_c[x_pick]) + if y_pick.numel() > 0: + out_vals.index_add_(0, y_target, y_vals_c[y_pick]) + + return _build_bsr_output(out_crow, out_col, out_vals, nrow, ncol, BR, BC) + + +class BlockJacobi: + def __init__(self, A, ridge: float = 1e-12): + crow, col, vals3d, (BR, BC), (nrow, _), dtype, device = _bsr_to_torch(A) + if BR != BC: + raise ValueError("Block-Jacobi requires square diagonal blocks") + self.BR = BR + self.nrow = nrow + self.dtype = dtype + self.device = device + self.ridge = ridge + nnz = int(col.numel()) + if nnz > 0: + crow_l = crow.to(torch.int64) + col_l = col.to(torch.int64) + block_row = torch.repeat_interleave( + torch.arange(nrow, device=device, dtype=torch.int64), + crow_l[1:] - crow_l[:-1], + ) + is_diag = block_row == col_l + self._diag_idx = is_diag.nonzero(as_tuple=False).flatten() + self._diag_rows = block_row[self._diag_idx] + else: + self._diag_idx = torch.empty(0, dtype=torch.int64, device=device) + self._diag_rows = torch.empty(0, dtype=torch.int64, device=device) + self._diag_blocks = torch.zeros(nrow, BR, BR, dtype=dtype, device=device) + 
def cg(matvec, b, x=None, M=None, tol: float = 1e-5, maxiter: Optional[int] = None, *, r_buf=None, p_buf=None):
    """Solve ``A x = b`` with (optionally preconditioned) conjugate gradients.

    Args:
        matvec: Callable returning ``A @ v`` for a tensor ``v`` shaped like ``b``.
            CG presumes the operator is symmetric positive definite.
        b: Right-hand side tensor.
        x: Optional initial guess. It is updated in place and returned; a zero
            tensor shaped like ``b`` is used when omitted.
        M: Optional callable applying the preconditioner to the residual.
        tol: Relative tolerance; iteration stops once
            ``||r||^2 <= tol^2 * ||b||^2``.
        maxiter: Iteration cap. ``None`` or ``0`` means ``b.numel()``.
        r_buf: Optional preallocated tensor reused to hold the residual.
        p_buf: Optional preallocated tensor reused to hold the search direction.

    Returns:
        The solution estimate (the same tensor object as ``x`` when provided).
    """

    def _dot(u, v):
        # Flatten at every use so the result always reflects the current
        # contents, even when a caller-supplied buffer is not contiguous.
        return torch.dot(u.reshape(-1), v.reshape(-1))

    if x is None:
        x = torch.zeros_like(b)

    if maxiter is None or maxiter == 0:
        maxiter = b.numel()

    b_norm_sq = _dot(b, b).item()
    if b_norm_sq == 0.0:
        # A zero right-hand side has the zero vector as its solution; do not
        # hand back a stale non-zero initial guess.
        x.zero_()
        return x

    atol_sq = (tol ** 2) * b_norm_sq

    Ax = matvec(x)
    if r_buf is None:
        r = b - Ax
    else:
        r = r_buf
        torch.sub(b, Ax, out=r)

    # Without a preconditioner ``z`` aliases ``r`` and follows its in-place updates.
    z = M(r) if M is not None else r

    if p_buf is None:
        p = z.clone()
    else:
        p = p_buf
        p.copy_(z)

    rz = _dot(r, z)
    r_norm_sq = _dot(r, r)

    for _ in range(maxiter):
        if r_norm_sq.item() <= atol_sq:
            break

        Ap = matvec(p)
        alpha = (rz / _dot(p, Ap)).item()
        if not (0.0 < alpha < float("inf")):
            # Breakdown: zero/negative curvature or a vanished ``r^T z``
            # (singular or indefinite operator/preconditioner). Stop with the
            # current iterate instead of writing inf/NaN into ``x``.
            break

        x.add_(p, alpha=alpha)
        r.add_(Ap, alpha=-alpha)

        if M is not None:
            z = M(r)

        rz_new = _dot(r, z)
        beta = (rz_new / rz).item()
        p.mul_(beta).add_(z)
        rz = rz_new
        r_norm_sq = _dot(r, r)

    return x
class NormalMatVec:
    r"""Matrix-free normal-equation linear operator.

    Given a Jacobian J (dense, CSR/COO, or BSR), this operator represents:

        A x = J^T (J x) + damping * diag(J^T J) * x

    where diag(J^T J) is computed as the column-wise sum of squares of J.

    Args:
        J: 2-D Jacobian tensor.
        damping: Scalar damping factor, either a float or a one-element tensor.
        diag: Optional precomputed ``diag(J^T J)``. Must be 1-D with length
            ``J.shape[1]``; it is moved to J's device and dtype.

    Raises:
        TypeError: If ``J`` is not a tensor.
        ValueError: If ``J`` is not 2-D, ``diag`` has the wrong shape, or a
            tensor ``damping`` is not scalar.
    """

    def __init__(
        self,
        J: Tensor,
        damping: Union[float, Tensor] = 0.0,
        diag: Optional[Tensor] = None,
    ):
        if not torch.is_tensor(J):
            raise TypeError("J must be a torch.Tensor")
        if J.ndim != 2:
            raise ValueError("J must be 2-D")

        self.J: Tensor = J
        # Transposed Jacobian, built lazily on the first matvec.
        self._Jt: Optional[Tensor] = None

        self.device = J.device
        self.dtype = J.dtype
        self.layout = J.layout if J.layout != torch.sparse_coo else torch.sparse_csr

        self.shape: Tuple[int, int] = (J.shape[1], J.shape[1])
        self.ndim: int = 2

        if diag is None:
            diag = self._compute_diag(J)
        else:
            # Same coercion that set_damping applies to tensor damping, so a
            # cached diagonal from another device/precision cannot leak in.
            diag = diag.to(device=self.device, dtype=self.dtype)
        self._diag: Tensor = diag
        if self._diag.ndim != 1 or self._diag.numel() != J.shape[1]:
            raise ValueError("diag must be 1-D with length equal to J.shape[1]")

        self.set_damping(damping)

    def set_damping(self, damping: Union[float, Tensor]) -> None:
        """Replace the damping factor (float or one-element tensor)."""
        if isinstance(damping, Tensor):
            if damping.numel() != 1:
                raise ValueError("damping tensor must be scalar")
            self.damping = damping.to(device=self.device, dtype=self.dtype)
        else:
            self.damping = float(damping)

    def diagonal(self) -> Tensor:
        """Return the diagonal of the operator: ``diag(J^T J) * (1 + damping)``."""
        damp = self._damping_value()
        if damp is None:
            return self._diag
        return self._diag * (1.0 + damp)

    def matvec(self, x: Tensor) -> Tensor:
        """Apply the operator to a 1-D vector or a 2-D stack of column vectors.

        The input is moved to the operator's device and dtype if needed; the
        result has the same number of dimensions as ``x``.
        """
        if x.ndim == 1:
            x2d = x.unsqueeze(-1)
        elif x.ndim == 2:
            x2d = x
        else:
            raise ValueError("x must be 1-D or 2-D")

        if x2d.device != self.device or x2d.dtype != self.dtype:
            x2d = x2d.to(device=self.device, dtype=self.dtype)

        y = self.J @ x2d
        Jt = self._get_Jt()
        z = Jt @ y

        damp = self._damping_value()
        if damp is not None:
            z = z + damp * self._diag.unsqueeze(-1) * x2d

        return z.squeeze(-1) if x.ndim == 1 else z

    def __call__(self, x: Tensor) -> Tensor:
        return self.matvec(x)

    def __matmul__(self, x: Tensor) -> Tensor:
        return self.matvec(x)

    @classmethod
    def __torch_function__(  # type: ignore[override]
        cls, func: Any, types: Any, args: Tuple[Any, ...] = (), kwargs: Optional[dict] = None
    ):
        """Route ``torch.matmul(op, x)`` (optionally with ``out=``) to ``matvec``."""
        if kwargs is None:
            kwargs = {}
        if func is torch.matmul and len(args) >= 2:
            A, B = args[0], args[1]
            if isinstance(A, cls):
                out = kwargs.get("out", None)
                result = A.matvec(B)
                if out is not None:
                    out_tensor = out[0] if isinstance(out, tuple) else out
                    out_tensor.copy_(result)
                    return out_tensor
                return result
        return NotImplemented

    def _get_Jt(self) -> Tensor:
        """Return the cached transpose of J, converted to a row-compressed layout."""
        if self._Jt is None:
            Jt = self.J.mT
            if Jt.layout == torch.sparse_csc:
                Jt = Jt.to_sparse_csr()
            elif Jt.layout == torch.sparse_bsc:
                bs = Jt.values().shape[-2:]
                Jt = Jt.to_sparse_bsr(blocksize=bs)
            self._Jt = Jt
        return self._Jt

    def _damping_value(self) -> Optional[Tensor]:
        """Return the damping as a tensor, or ``None`` when it is exactly zero."""
        if isinstance(self.damping, Tensor):
            if self.damping.item() == 0.0:
                return None
            return self.damping
        if self.damping == 0.0:
            return None
        return torch.tensor(self.damping, device=self.device, dtype=self.dtype)

    @staticmethod
    def _compute_diag(J: Tensor) -> Tensor:
        """Compute ``diag(J^T J)`` as the column-wise sum of squares of J.

        Raises:
            NotImplementedError: For layouts other than strided, BSR, CSR, COO.
        """
        if J.layout == torch.strided:
            return J.square().sum(dim=0)

        if J.layout == torch.sparse_bsr:
            values = J.values()
            dn = values.shape[-1]
            # Widen before multiplying: the stored indices may be int32 and
            # ``block_col * dn`` must not wrap for very wide Jacobians.
            col_blocks = J.col_indices().to(torch.int64)
            contrib = values.square().sum(dim=-2)  # (nnz_blocks, dn)
            offsets = torch.arange(dn, device=contrib.device, dtype=torch.int64)
            cols = (col_blocks[:, None] * dn + offsets[None, :]).reshape(-1)
            contrib_flat = contrib.reshape(-1)
            diag = torch.zeros(J.shape[1], device=contrib.device, dtype=contrib.dtype)
            diag.scatter_add_(0, cols, contrib_flat)
            return diag

        if J.layout == torch.sparse_csr:
            values = J.values()
            col = J.col_indices().to(torch.int64)
            v2 = values.square()
            if v2.ndim > 1:
                v2 = v2.reshape(v2.shape[0], -1).sum(dim=-1)
            diag = torch.zeros(J.shape[1], device=values.device, dtype=v2.dtype)
            diag.scatter_add_(0, col, v2)
            return diag

        if J.layout == torch.sparse_coo:
            Jc = J.coalesce()
            col = Jc.indices()[1].to(torch.int64)
            v2 = Jc.values().square()
            if v2.ndim > 1:
                v2 = v2.reshape(v2.shape[0], -1).sum(dim=-1)
            diag = torch.zeros(J.shape[1], device=Jc.device, dtype=v2.dtype)
            diag.scatter_add_(0, col, v2)
            return diag

        raise NotImplementedError(f"Unsupported J layout: {J.layout}")
def test_pcg_smoke_with_normal_operator():
    """PCG applied to the damped normal operator should match a dense direct solve."""
    torch.manual_seed(4)
    num_rows, num_cols = 12, 5
    lam = 1e-3

    # Draw the Jacobian before the right-hand side to keep the seeded values.
    jac = torch.randn(num_rows, num_cols, dtype=torch.float64)
    normal_op = NormalMatVec(jac, damping=lam)
    rhs = torch.randn(num_cols, dtype=torch.float64)

    pcg = PCG(tol=1e-10, maxiter=200)
    got = pcg(normal_op, rhs)

    # Reference: assemble J^T J + lam * diag(J^T J) explicitly and solve it.
    col_sq = jac.square().sum(dim=0)
    dense_system = jac.mT @ jac + lam * torch.diag(col_sq)
    want = torch.linalg.solve(dense_system, rhs)

    torch.testing.assert_close(got, want, rtol=1e-6, atol=1e-6)