diff --git a/.agents/plans/2026-06-10-ivf-wave.md b/.agents/plans/2026-06-10-ivf-wave.md new file mode 100644 index 0000000..251ff7b --- /dev/null +++ b/.agents/plans/2026-06-10-ivf-wave.md @@ -0,0 +1,27 @@ +# IVF / coarse-quantizer wave + repo finish-up + +## Context + +Final roadmap item for quantvec: IVF (inverted-file) coarse quantizer for sublinear search on 10M+ corpora. User decisions: **full parity** (TurboQuantIndex + IdMapIndex + Collection), **full remove() support**, release as **v0.0.3**. Architecture: IVF as an opt-in option `ivf: { nlist, nprobe? }` on TurboQuantIndex (like `calibrate`/`fastscan`) — parity through the existing wrapper layering is then nearly free. Format VERSION bumps 1→2 (v2-only readers, sanctioned by repo ADR D-010: pre-1.0 bump-and-rewrite, no legacy readers). + +## Steps (in dependency order) + +1. **`src/core/kmeans.ts`** (new) — seeded k-means++ init + Lloyd, `kmeans(data, m, {k, dim, rng, spherical, maxIterations=25})` → `{centroids, assignments, iterations}`. Spherical mode (cosine/dot): renormalize centroids each round, zero-mean keeps previous direction. Convergence: zero assignment changes. Empty-cluster repair: re-seed from farthest point (deterministic first-max). `KMeansError` codes `INVALID_K|INVALID_DIM|INVALID_LENGTH`; k ∈ [2, m]. Cite Lloyd 1982 / Arthur & Vassilvitskii 2007. Tests: determinism, separated blobs, repair, spherical unit norms, k=m, validation. + +2. **`src/core/search.ts`** — refactor query prep (validation, norms, rotation, calibration dual, LUT) into private `prepareScan`; add exported `searchSlots(db, query, k, slots: Int32Array, opts)` scanning only given slots, honoring full-length mask, same errors + new `SearchError` code `INVALID_SLOT`. Tests: ≡ searchFlat over all slots (exact), subset-only results, mask interaction, empty slots, INVALID_SLOT. + +3. 
**`src/index/coarse.ts`** (new) — `CoarseQuantizer`: `train(vecs, nlist, nprobe, metric, dim, seed)` (sample min(m, 64·nlist) via partial Fisher-Yates with domain-separated rng seed; spherical for cosine/dot), `fromState(centroids, listForSlot, ...)`, `assign`, `addSlot`, `swapRemove(i, last)` (two-step dance: A) swap-pop slot i from its list, B) renumber last→i reading posForSlot AFTER step A; handles i===last and same-list-tail), `clear()` (keeps centroids), `probe(query, nprobe)` → concatenated Int32Array of top-nprobe lists' slots via TopK. State: `#postings: number[][]`, `#listForSlot`, `#posForSlot`. `defaultNprobe(nlist) = max(1, ceil(nlist/8))`. Tests: invariant fuzz (postings[listForSlot[s]][posForSlot[s]] === s after every op), train determinism, probe(nlist) = all slots, fromState round-trip. + +4. **`src/index/turboquant-index.ts`** — `TurboQuantIndexOptions.ivf?: IvfOptions {nlist (int, [2, 2^22]), nprobe? (int, [1,nlist])}`; `IndexSearchOptions.nprobe?`; `IndexError` codes `INVALID_NLIST|INVALID_NPROBE`; getter `ivfActive`. `@internal trainIvfFromBatch(vecs)` mirroring fitCalibrationFromBatch (train iff first batch ≥ nlist, else freeze flat forever); hooks: `add()` after calibration call, `#appendOne` → `addSlot(slot, rawVec)`, `swapRemove` → coarse.swapRemove(i, last), `clear()` → coarse.clear(). Search: IVF branch before WASM block (`probe` + `searchSlots`); nprobe ignored when flat (documented); WASM/FastScan bypassed under IVF (documented as future wave). toPayload/fromPayload carry ivf state; fromPayload freezes. Tests: validation, train/freeze semantics, nprobe=nlist ≡ flat (exact scores), recall sanity on 64-cluster mixture, remove parity vs flat twin, round-trip, mask interaction. + +5. **`src/io/serialize.ts`** — VERSION = 2 (read accepts only 2). New section after calibration, before ids: flag u8 ∈ {0,1}; if 1: nlist u32 + nprobe u32 + centroids nlist·dim f32 + listForSlot n·u32 (postings rebuilt on load). 
`DeserializeError` code `BAD_IVF`; bounds-check before allocation; validate nlist ∈ [2, 2^22], nprobe ∈ [1, nlist], centroids finite, listForSlot entries < nlist. `IndexPayload.ivf?`. Tests: round-trips (±ivf, ±calibration, both kinds), crafted v1 → BAD_VERSION, all corrupt-IVF branches, n=0 with ivf. + +6. **Plumbing** — IdMapIndex: call `trainIvfFromBatch` in addWithIds beside calibration; `IdMapSearchOptions.nprobe?` forwarded; `ivfActive` getter. Collection: `CollectionConfig.ivf?`, constructor passes through; `SearchParams.nprobe?` forwarded. `src/index.ts`: export `IvfOptions` type. Tests in id-map-index.test.ts + collection.test.ts (passthrough, remove/delete under ivf, filter+ivf, round-trip). + +7. **Bench + docs + release** — `benchmarks/ivf.ts` (synthetic 64-cluster mixture, dim 768, n 20k, nlist 128; flat baseline + nprobe ∈ {1,2,4,8,16,128}; recall@1/10/100, QPS, METRIC lines, results JSON) + `bench:ivf` script. Docs: README (scope note, quickstart, roadmap table ✅), docs/roadmap.md (Planned→Shipped), serialization.md (v2 layout), api-reference.md, guide.md, architecture.md (module map), benchmarks.md. Version bump 0.0.3 (publish trigger on main merge). Branch `feat/ivf`, PR to main; do NOT merge without user confirmation (merge auto-publishes to npm). + +## Verification +- nprobe=nlist ≡ searchFlat exactly (IVF analog of the WASM≡scalar oracle). +- Invariant fuzz over add/remove/clear; determinism (same seed+data ⇒ byte-identical toBytes()). +- Every BAD_IVF branch hit by a crafted-buffer test; no allocation before bounds check. +- `bun run typecheck && lint && format:check && test:coverage` (90% global) && build; run bench:ivf, commit results JSON. 
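
The two-step `swapRemove` dance from step 3 can be sketched as below. This is a minimal illustration under stated assumptions, not the code in `src/index/coarse.ts`: the class name `PostingLists`, the public fields, and `holdsInvariant()` are hypothetical, and only the `postings` / `listForSlot` / `posForSlot` naming and the step A / step B ordering come from the plan.

```typescript
// Hypothetical sketch of the posting-list bookkeeping described in step 3.
// Invariant after every operation: postings[listForSlot[s]][posForSlot[s]] === s.
class PostingLists {
  readonly postings: number[][];
  readonly listForSlot: number[] = [];
  readonly posForSlot: number[] = [];

  constructor(nlist: number) {
    this.postings = Array.from({ length: nlist }, () => [] as number[]);
  }

  get size(): number {
    return this.listForSlot.length;
  }

  /** Append the next slot (its number is the current size) to posting list `list`. */
  addSlot(list: number): number {
    const posting = this.postings[list];
    if (posting === undefined) throw new RangeError(`list ${list} out of range`);
    const slot = this.listForSlot.length;
    this.listForSlot.push(list);
    this.posForSlot.push(posting.length);
    posting.push(slot);
    return slot;
  }

  /** Remove slot `i`; the last slot is renumbered to `i`, mirroring the flat index. */
  swapRemove(i: number): void {
    const last = this.listForSlot.length - 1;
    if (!Number.isInteger(i) || i < 0 || i > last) throw new RangeError(`slot ${i} out of range`);

    // Step A: swap-pop slot i out of its own posting list.
    const posting = this.postings[this.listForSlot[i]!]!;
    const posI = this.posForSlot[i]!;
    const tail = posting.pop()!;
    if (tail !== i) {
      posting[posI] = tail;
      this.posForSlot[tail] = posI;
    }

    // Step B: renumber last -> i. posForSlot[last] is read only now, because
    // step A may have moved `last` when it was the tail of the same list.
    if (i !== last) {
      const listLast = this.listForSlot[last]!;
      const posLast = this.posForSlot[last]!;
      this.postings[listLast]![posLast] = i;
      this.listForSlot[i] = listLast;
      this.posForSlot[i] = posLast;
    }
    this.listForSlot.pop();
    this.posForSlot.pop();
  }

  /** Check the invariant that the plan's fuzz test asserts after every operation. */
  holdsInvariant(): boolean {
    let total = 0;
    for (const posting of this.postings) total += posting.length;
    if (total !== this.size) return false;
    for (let s = 0; s < this.size; s++) {
      if (this.postings[this.listForSlot[s]!]![this.posForSlot[s]!] !== s) return false;
    }
    return true;
  }
}

// Usage: slots 0, 2, 4 share list 0, so removing slot 0 hits the same-list-tail case.
const demo = new PostingLists(3);
for (const list of [0, 1, 0, 2, 0]) demo.addSlot(list);
demo.swapRemove(0);
console.log(demo.size, demo.holdsInvariant());
```

Reading `posForSlot[last]` before step A would leave a stale position whenever `last` is the tail of the list that slot `i` lives in, which is presumably why the plan calls out that ordering explicitly.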
diff --git a/README.md b/README.md index 78da245..1d654a2 100644 --- a/README.md +++ b/README.md @@ -37,8 +37,9 @@ scalar codebook is fully determined by `(dim, bits)` with **no data and ~zero in | Dependencies | **Zero** runtime dependencies | > **Scope:** quantvec is a _flat quantized index_ — O(n) scan over compact codes (à la FAISS -> `IndexPQFastScan`). Great recall and throughput up to ~1–10M vectors. An IVF coarse-quantizer -> for larger corpora is on the roadmap. +> `IndexPQFastScan`) — with an opt-in **IVF coarse quantizer** (`ivf: { nlist }`) that probes +> only the nearest cells for sublinear search on large corpora (**11× QPS at equal recall** +> measured at 20k vectors; the gain grows with n). A 1M × 1536-d corpus (e.g. OpenAI `text-embedding-ada-002`) is **6.1 GB as float32**. At 4 bits quantvec packs it into **~780 MB** (7.92×); at 2 bits, **~390 MB** (15.67×) — with **94%+ @@ -78,6 +79,16 @@ exact rescore of the candidate pool): const index = new TurboQuantIndex({ dim: 1536, bits: 4, fastscan: true }); ``` +For large corpora, enable the **IVF coarse quantizer** — k-means cells are trained from the +first add (needs ≥ `nlist` vectors; ~32·nlist recommended) and queries probe only the nearest +`nprobe` cells (sublinear scan; ~11× QPS at equal recall on clustered data): + +```ts +const index = new TurboQuantIndex({ dim: 1536, ivf: { nlist: 1024 } }); +index.add(corpus); // first batch trains + freezes the cells +index.search(query, 10, { nprobe: 32 }); // per-query recall/speed knob +``` + ### Stable ids: `IdMapIndex` ```ts @@ -275,18 +286,18 @@ Full results and JSON in [`benchmarks/`](./benchmarks/). 
## Roadmap -| Status | Item | -| ------ | -------------------------------------------------------------------------------------------------------- | -| ✅ | Core math: rotation, Beta/Lloyd-Max codebooks, encode pipeline, flat nibble-LUT search | -| ✅ | `TurboQuantIndex`, `IdMapIndex`, versioned serialization, Node fs helpers | -| ✅ | True 2/3/4-bit **bit-packed serialization** (7.9–15.7× compression) | -| ✅ | **FWHT rotation** for power-of-two dims (O(d·log d), ~25× faster encode) | -| ✅ | **TQ+ per-coordinate calibration** (opt-in; data-dependent) | -| ✅ | **Exact WASM scoring kernel** (AssemblyScript, bit-identical to scalar, ~1.3× query) | -| ✅ | **v128 FastScan kernel** (blocked-nibble swizzle + exact rescore, **~5.7× query**) | -| ✅ | **Ergonomic `createCollection`** with typed payloads and filter DSL | -| ✅ | Real-dataset benchmarks: SIFT-small + GloVe-200 + dbpedia-OpenAI-100k (results in `benchmarks/results/`) | -| 📋 | IVF / coarse-quantizer for 10M+ corpora | +| Status | Item | +| ------ | --------------------------------------------------------------------------------------------------------- | +| ✅ | Core math: rotation, Beta/Lloyd-Max codebooks, encode pipeline, flat nibble-LUT search | +| ✅ | `TurboQuantIndex`, `IdMapIndex`, versioned serialization, Node fs helpers | +| ✅ | True 2/3/4-bit **bit-packed serialization** (7.9–15.7× compression) | +| ✅ | **FWHT rotation** for power-of-two dims (O(d·log d), ~25× faster encode) | +| ✅ | **TQ+ per-coordinate calibration** (opt-in; data-dependent) | +| ✅ | **Exact WASM scoring kernel** (AssemblyScript, bit-identical to scalar, ~1.3× query) | +| ✅ | **v128 FastScan kernel** (blocked-nibble swizzle + exact rescore, **~5.7× query**) | +| ✅ | **Ergonomic `createCollection`** with typed payloads and filter DSL | +| ✅ | Real-dataset benchmarks: SIFT-small + GloVe-200 + dbpedia-OpenAI-100k (results in `benchmarks/results/`) | +| ✅ | **IVF / coarse-quantizer** for 10M+ corpora (k-means cells, full remove parity, ~11× 
QPS at equal recall) | --- diff --git a/benchmarks/ivf.ts b/benchmarks/ivf.ts new file mode 100644 index 0000000..702adf5 --- /dev/null +++ b/benchmarks/ivf.ts @@ -0,0 +1,216 @@ +// quantvec — IVF coarse-quantizer benchmark: speedup vs recall against the flat scan. +// +// Measures the IVF value proposition on clustered data (the regime IVF exists for): +// a seeded Gaussian-mixture corpus, an exact float32 cosine ground truth, a flat +// TurboQuantIndex baseline, and the same index with `ivf` enabled swept across +// nprobe values. At nprobe = nlist the IVF results are exactly the flat scan's +// (the searchSlots oracle), so the sweep shows the recall/QPS trade-off cleanly. +// +// Self-contained and deterministic. Emits a human table, `METRIC key=value` lines, +// and a JSON results file. Run: `npm run bench:ivf` (env: DIM, N, NQ, CLUSTERS, NLIST). + +import { mkdirSync, writeFileSync } from 'node:fs'; +import { join } from 'node:path'; +import { TurboQuantIndex } from '../src/index/turboquant-index'; + +// ── Deterministic PRNG + Gaussian (mulberry32 + Box–Muller) ──────────────────── +function mulberry32(seed: number): () => number { + let a = seed >>> 0; + return () => { + a |= 0; + a = (a + 0x6d2b79f5) | 0; + let t = Math.imul(a ^ (a >>> 15), 1 | a); + t = (t + Math.imul(t ^ (t >>> 7), 61 | t)) ^ t; + return ((t ^ (t >>> 14)) >>> 0) / 4294967296; + }; +} + +function gaussian(rng: () => number): number { + let u = 0; + let v = 0; + while (u === 0) u = rng(); + while (v === 0) v = rng(); + return Math.sqrt(-2 * Math.log(u)) * Math.cos(2 * Math.PI * v); +} + +/** Gaussian mixture: `clusters` centers (sigma 5), unit-sigma points around them. 
*/ +function makeClustered( + count: number, + dim: number, + clusters: number, + rng: () => number, +): { vectors: Float32Array[]; centers: Float32Array[] } { + const centers = Array.from({ length: clusters }, () => { + const c = new Float32Array(dim); + for (let i = 0; i < dim; i++) c[i] = gaussian(rng) * 5; + return c; + }); + const vectors: Float32Array[] = []; + for (let j = 0; j < count; j++) { + const c = centers[j % clusters]!; + const v = new Float32Array(dim); + for (let i = 0; i < dim; i++) v[i] = c[i]! + gaussian(rng); + vectors.push(v); + } + return { vectors, centers }; +} + +// ── Exact cosine top-k ground truth ──────────────────────────────────────────── +function normalized(v: Float32Array): Float32Array { + let s = 0; + for (let i = 0; i < v.length; i++) s += v[i]! * v[i]!; + const inv = 1 / Math.sqrt(s); + const out = new Float32Array(v.length); + for (let i = 0; i < v.length; i++) out[i] = v[i]! * inv; + return out; +} + +function exactTopK(db: Float32Array[], query: Float32Array, k: number): number[] { + const q = normalized(query); + const scored = db.map((v, idx) => { + const u = normalized(v); + let dot = 0; + for (let i = 0; i < u.length; i++) dot += u[i]! 
* q[i]!; + return { idx, dot }; + }); + scored.sort((a, b) => b.dot - a.dot); + return scored.slice(0, k).map((s) => s.idx); +} + +function recallAt(approx: Int32Array, exact: number[], k: number): number { + const truth = new Set(exact.slice(0, k)); + let hit = 0; + for (let i = 0; i < Math.min(k, approx.length); i++) if (truth.has(approx[i]!)) hit++; + return hit / k; +} + +// ── Benchmark one configuration ──────────────────────────────────────────────── +interface Row { + label: string; + nprobe?: number; + recall1: number; + recall10: number; + recall100: number; + qps: number; + speedupVsFlat: number; +} + +function benchIndex( + label: string, + index: TurboQuantIndex, + queries: Float32Array[], + exact: number[][], + nprobe?: number, +): Omit<Row, 'speedupVsFlat'> { + const K = 100; + const opts = nprobe === undefined ? {} : { nprobe }; + const approxAll: Int32Array[] = []; + const tSearch = performance.now(); + for (const q of queries) approxAll.push(index.search(q, K, opts).indices); + const searchSecs = (performance.now() - tSearch) / 1000; + + let r1 = 0; + let r10 = 0; + let r100 = 0; + for (let i = 0; i < queries.length; i++) { + r1 += recallAt(approxAll[i]!, exact[i]!, 1); + r10 += recallAt(approxAll[i]!, exact[i]!, 10); + r100 += recallAt(approxAll[i]!, exact[i]!, 100); + } + const nq = queries.length; + const row: Omit<Row, 'speedupVsFlat'> = { + label, + recall1: r1 / nq, + recall10: r10 / nq, + recall100: r100 / nq, + qps: nq / searchSecs, + }; + if (nprobe !== undefined) row.nprobe = nprobe; + return row; +} + +// ── Run ──────────────────────────────────────────────────────────────────────── +function main(): void { + const dim = Number(process.env.DIM ?? 768); + const n = Number(process.env.N ?? 20000); + const nq = Number(process.env.NQ ?? 200); + const clusters = Number(process.env.CLUSTERS ?? 64); + const nlist = Number(process.env.NLIST ?? 
128); + const rng = mulberry32(42); + + process.stdout.write( + `quantvec IVF benchmark — dim=${dim} n=${n} queries=${nq} clusters=${clusters} nlist=${nlist} (cosine, 4-bit)\n`, + ); + + const { vectors: db } = makeClustered(n, dim, clusters, rng); + const { vectors: queries } = makeClustered(nq, dim, clusters, mulberry32(7)); + process.stdout.write('computing exact float32 ground truth…\n'); + const exact = queries.map((q) => exactTopK(db, q, 100)); + + // Flat baseline (scalar path — the IVF scan is scalar too, so the comparison is fair). + const flat = new TurboQuantIndex({ dim, bits: 4, metric: 'cosine', seed: 1, wasm: false }); + flat.add(db); + const flatRow: Row = { ...benchIndex('flat', flat, queries, exact), speedupVsFlat: 1 }; + + // IVF index: trained from the same single batch. + const tTrain = performance.now(); + const ivf = new TurboQuantIndex({ dim, bits: 4, metric: 'cosine', seed: 1, ivf: { nlist } }); + ivf.add(db); + const trainSecs = (performance.now() - tTrain) / 1000; + if (!ivf.ivfActive) throw new Error('IVF did not train — first batch smaller than nlist?'); + + const sweep = [1, 2, 4, 8, 16, nlist].filter((p, i, arr) => arr.indexOf(p) === i && p <= nlist); + const rows: Row[] = [flatRow]; + for (const nprobe of sweep) { + const r = benchIndex(`ivf@${nprobe}`, ivf, queries, exact, nprobe); + rows.push({ ...r, speedupVsFlat: r.qps / flatRow.qps }); + } + + // Human table. + process.stdout.write(`\nbuild+train: ${trainSecs.toFixed(2)}s for ${n} vectors\n`); + process.stdout.write('\nconfig | recall@1 | recall@10 | recall@100 | QPS | speedup\n'); + process.stdout.write('----------|----------|-----------|------------|--------|--------\n'); + for (const r of rows) { + process.stdout.write( + `${r.label.padEnd(9)} | ${r.recall1.toFixed(3)} | ${r.recall10.toFixed(3)} | ${r.recall100.toFixed( + 3, + )} | ${Math.round(r.qps).toString().padStart(6)} | ${r.speedupVsFlat.toFixed(2)}x\n`, + ); + } + + // METRIC lines (autoresearch protocol). 
+ process.stdout.write('\n'); + process.stdout.write(`METRIC flat_qps=${Math.round(flatRow.qps)}\n`); + for (const r of rows) { + if (r.nprobe === undefined) continue; + process.stdout.write(`METRIC ivf_recall_at10_nprobe${r.nprobe}=${r.recall10.toFixed(4)}\n`); + process.stdout.write(`METRIC ivf_qps_nprobe${r.nprobe}=${Math.round(r.qps)}\n`); + } + + // JSON results. + const outDir = join(process.cwd(), 'benchmarks', 'results'); + mkdirSync(outDir, { recursive: true }); + const outPath = join(outDir, `ivf-d${dim}.json`); + writeFileSync( + outPath, + JSON.stringify( + { + dim, + n, + nq, + clusters, + nlist, + metric: 'cosine', + bits: 4, + trainSecs, + generatedAt: new Date().toISOString(), + rows, + }, + null, + 2, + ), + ); + process.stdout.write(`\nwrote ${outPath}\n`); +} + +main(); diff --git a/benchmarks/results/ivf-d768.json b/benchmarks/results/ivf-d768.json new file mode 100644 index 0000000..fa66f26 --- /dev/null +++ b/benchmarks/results/ivf-d768.json @@ -0,0 +1,75 @@ +{ + "dim": 768, + "n": 20000, + "nq": 200, + "clusters": 64, + "nlist": 128, + "metric": "cosine", + "bits": 4, + "trainSecs": 17.951041874999994, + "generatedAt": "2026-06-10T12:35:16.896Z", + "rows": [ + { + "label": "flat", + "recall1": 0.355, + "recall10": 0.6029999999999996, + "recall100": 0.7497500000000001, + "qps": 52.78636802603001, + "speedupVsFlat": 1 + }, + { + "label": "ivf@1", + "recall1": 0.24, + "recall10": 0.3864999999999999, + "recall100": 0.4408500000000001, + "qps": 1205.0945395762583, + "nprobe": 1, + "speedupVsFlat": 22.82965440967642 + }, + { + "label": "ivf@2", + "recall1": 0.305, + "recall10": 0.5320000000000001, + "recall100": 0.6422999999999999, + "qps": 1090.9819903257126, + "nprobe": 2, + "speedupVsFlat": 20.667873754597544 + }, + { + "label": "ivf@4", + "recall1": 0.355, + "recall10": 0.6019999999999996, + "recall100": 0.7372500000000002, + "qps": 851.6927813487874, + "nprobe": 4, + "speedupVsFlat": 16.134710782314116 + }, + { + "label": "ivf@8", + "recall1": 
0.36, + "recall10": 0.6029999999999996, + "recall100": 0.7495499999999999, + "qps": 600.3408441145751, + "nprobe": 8, + "speedupVsFlat": 11.373028048047084 + }, + { + "label": "ivf@16", + "recall1": 0.355, + "recall10": 0.6029999999999996, + "recall100": 0.7498500000000001, + "qps": 377.1228478449929, + "nprobe": 16, + "speedupVsFlat": 7.144322709587182 + }, + { + "label": "ivf@128", + "recall1": 0.355, + "recall10": 0.6029999999999996, + "recall100": 0.7497500000000001, + "qps": 59.81218376773218, + "nprobe": 128, + "speedupVsFlat": 1.1330990557682923 + } + ] +} \ No newline at end of file diff --git a/docs/api-reference.md b/docs/api-reference.md index 169804b..c58af56 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -10,31 +10,33 @@ A growable, positional flat quantized index. new TurboQuantIndex(options: TurboQuantIndexOptions) ``` -| Option | Type | Default | Notes | -| ----------- | ---------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `dim` | `number` | — | positive multiple of 8 | -| `bits` | `2 \| 3 \| 4` | `4` | quantizer bit-width | -| `metric` | `'cosine' \| 'dot' \| 'euclidean'` | `'cosine'` | default ranking metric | -| `seed` | `number` | `0` | rotation RNG seed (finite; truncated to an integer) | -| `calibrate` | `boolean` | `false` | opt-in TQ+ per-coordinate calibration (fit from the first ≥1000-vector add; data-dependent) | -| `wasm` | `boolean` | `true` | use the WASM scoring kernel when available (exact; auto-falls back to the scalar scan) | -| `fastscan` | `boolean` | `false` | use the v128 FastScan kernel (4-bit only; approximate SIMD ranking + exact rescore; falls back to the exact kernel when `bits ≠ 4` or WASM unavailable) | +| Option | Type | Default | Notes | +| ----------- | ------------------------------------ | ---------- | 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `dim` | `number` | — | positive multiple of 8 | +| `bits` | `2 \| 3 \| 4` | `4` | quantizer bit-width | +| `metric` | `'cosine' \| 'dot' \| 'euclidean'` | `'cosine'` | default ranking metric | +| `seed` | `number` | `0` | rotation RNG seed (finite; truncated to an integer) | +| `calibrate` | `boolean` | `false` | opt-in TQ+ per-coordinate calibration (fit from the first ≥1000-vector add; data-dependent) | +| `wasm` | `boolean` | `true` | use the WASM scoring kernel when available (exact; auto-falls back to the scalar scan) | +| `fastscan` | `boolean` | `false` | use the v128 FastScan kernel (4-bit only; approximate SIMD ranking + exact rescore; falls back to the exact kernel when `bits ≠ 4` or WASM unavailable) | +| `ivf` | `{ nlist: number; nprobe?: number }` | off | opt-in IVF coarse quantizer: `nlist ∈ [2, 2^22]` k-means cells trained from the first ≥ nlist-vector add and frozen; queries probe `nprobe` cells (default ⌈nlist/8⌉). 
Whole-database WASM/FastScan kernels are bypassed while active | **Methods & getters** -| Member | Signature | Description | -| ---------------------------------------------------------- | ----------------------------------------------------------------------------- | --------------------------------------------- | -| `add` | `(vectors: Float32Array \| number[][] \| Float32Array[]) => void` | append a batch | -| `addOne` | `(vec: Float32Array \| number[]) => void` | append one | -| `search` | `(query: Float32Array, k: number, opts?: IndexSearchOptions) => SearchResult` | k nearest | -| `swapRemove` | `(i: number) => void` | O(1) delete; moves the last row into slot `i` | -| `clear` | `() => void` | drop all vectors (keeps capacity) | -| `toBytes` | `() => Uint8Array` | serialize | -| `TurboQuantIndex.fromBytes` | `(bytes: Uint8Array) => TurboQuantIndex` | deserialize (static) | -| `size` / `dim` / `bits` / `metric` / `seed` / `calibrated` | getters | live count + config + whether TQ+ is active | - -`IndexSearchOptions`: `{ metric?: Distance; mask?: Uint8Array \| boolean[] }` (mask length = `size`, -positional). `SearchResult`: `{ indices: Int32Array; scores: Float32Array }`. 
+| Member | Signature | Description | +| ------------------------------------------------------------------------ | ----------------------------------------------------------------------------- | ------------------------------------------------ | +| `add` | `(vectors: Float32Array \| number[][] \| Float32Array[]) => void` | append a batch | +| `addOne` | `(vec: Float32Array \| number[]) => void` | append one | +| `search` | `(query: Float32Array, k: number, opts?: IndexSearchOptions) => SearchResult` | k nearest | +| `swapRemove` | `(i: number) => void` | O(1) delete; moves the last row into slot `i` | +| `clear` | `() => void` | drop all vectors (keeps capacity) | +| `toBytes` | `() => Uint8Array` | serialize | +| `TurboQuantIndex.fromBytes` | `(bytes: Uint8Array) => TurboQuantIndex` | deserialize (static) | +| `size` / `dim` / `bits` / `metric` / `seed` / `calibrated` / `ivfActive` | getters | live count + config + whether TQ+/IVF are active | + +`IndexSearchOptions`: `{ metric?: Distance; mask?: Uint8Array \| boolean[]; nprobe?: number }` +(mask length = `size`, positional; `nprobe` overrides the IVF probe breadth and is ignored when +IVF is not active). `SearchResult`: `{ indices: Int32Array; scores: Float32Array }`. ## `IdMapIndex` @@ -44,19 +46,19 @@ Stable-id layer over `TurboQuantIndex`. 
`Id extends number | string | bigint`, d new IdMapIndex(options: TurboQuantIndexOptions) ``` -| Member | Signature | Description | -| ---------------------------------------------------------- | --------------------------------------------------------------------------------------- | ------------------------------------------- | -| `addWithIds` | `(ids: readonly Id[], vectors: Float32Array \| number[][] \| Float32Array[]) => void` | append with ids | -| `search` | `(query: Float32Array, k: number, opts?: IdMapSearchOptions) => IdSearchResult` | k nearest, by id | -| `has` | `(id: Id) => boolean` | membership | -| `remove` | `(id: Id) => void` | O(1) delete by id | -| `ids` | `() => Id[]` | snapshot of all ids (slot order) | -| `clear` | `() => void` | empty the index | -| `toBytes` | `() => Uint8Array` | serialize | -| `IdMapIndex.fromBytes` | `(bytes: Uint8Array) => IdMapIndex` | deserialize (static; assert `Id`) | -| `size` / `dim` / `bits` / `metric` / `seed` / `calibrated` | getters | live count + config + whether TQ+ is active | - -`IdMapSearchOptions`: `{ metric?: Distance; filter?: (id: Id) => boolean }`. 
+| Member | Signature | Description | +| ------------------------------------------------------------------------ | --------------------------------------------------------------------------------------- | ------------------------------------------------ | +| `addWithIds` | `(ids: readonly Id[], vectors: Float32Array \| number[][] \| Float32Array[]) => void` | append with ids | +| `search` | `(query: Float32Array, k: number, opts?: IdMapSearchOptions) => IdSearchResult` | k nearest, by id | +| `has` | `(id: Id) => boolean` | membership | +| `remove` | `(id: Id) => void` | O(1) delete by id | +| `ids` | `() => Id[]` | snapshot of all ids (slot order) | +| `clear` | `() => void` | empty the index | +| `toBytes` | `() => Uint8Array` | serialize | +| `IdMapIndex.fromBytes` | `(bytes: Uint8Array) => IdMapIndex` | deserialize (static; assert `Id`) | +| `size` / `dim` / `bits` / `metric` / `seed` / `calibrated` / `ivfActive` | getters | live count + config + whether TQ+/IVF are active | + +`IdMapSearchOptions`: `{ metric?: Distance; filter?: (id: Id) => boolean; nprobe?: number }`. `IdSearchResult`: `{ ids: Id[]; scores: Float32Array }`. ## `quantvec/node` diff --git a/docs/architecture.md b/docs/architecture.md index 229136a..1f9dfb4 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -63,7 +63,8 @@ a bounded min-heap. 
This pure-TypeScript scalar kernel is the **correctness orac | `core/rng`, `core/rotation`, `core/fwht` | seeded RNG, rotation (Householder QR, or FWHT for power-of-two dims) | | `core/beta`, `core/codebook` | Beta pdf/cdf/quantile, Lloyd-Max codebooks per `(dim, bits)` | | `core/encode`, `core/pack`, `core/calibrate` | normalize→rotate→(TQ+)→quantize→scale; bit-pack; calibration fit | -| `core/search`, `core/topk`, `core/metrics` | nibble-LUT scan, bounded heap, distance math | +| `core/search`, `core/topk`, `core/metrics` | nibble-LUT scan (flat + probed-slot subset), bounded heap, distance math | +| `core/kmeans`, `index/coarse` | seeded k-means++ / Lloyd; IVF cell structure (centroids + posting lists in lockstep with swap-remove) | | `wasm/kernel` + `assembly/` | WASM kernels: exact f64 scoreInto (bit-identical to scalar oracle) + v128 FastScan (blocked-nibble swizzle + rescore) | | `ergonomic/collection`, `ergonomic/filter` | `createCollection`, `Collection

`, `must`/`should`/`must_not` filter DSL | | `index/turboquant-index` | growable positional flat index | @@ -72,5 +73,8 @@ a bounded min-heap. This pure-TypeScript scalar kernel is the **correctness orac ## Scope -quantvec is a **flat** quantized index — search is an O(n) scan, excellent to ~1–10M vectors. It is -not an HNSW graph; an IVF/coarse-quantizer layer for larger corpora is on the [roadmap](roadmap.md). +quantvec is a **flat** quantized index — search is an O(n) scan, excellent to ~1–10M vectors — with +an opt-in **IVF coarse quantizer** (`ivf: { nlist }`): k-means cells trained from the first batch, +queries probe only the `nprobe` nearest cells (sublinear scan; the whole-database WASM kernels are +bypassed while IVF is active — a cell-resident kernel is a future wave). It is not an HNSW graph +([roadmap](roadmap.md) non-goal). diff --git a/docs/benchmarks.md b/docs/benchmarks.md index fca13c1..3119927 100644 --- a/docs/benchmarks.md +++ b/docs/benchmarks.md @@ -87,3 +87,21 @@ The SIMD scan cost is O(n) while the rescore-pool overhead is constant, so the g embeddings; neutral-to-negative on well-conditioned synthetic data. - **FWHT** is used automatically for power-of-two dims (128, 256, 512, 768, 1024, 1536…); O(d·log d) vs O(d²) for the dense rotation — ~25× faster encode at no recall cost. 
+ +## IVF coarse quantizer (synthetic, clustered) + +`npm run bench:ivf` — 20k × 768-d Gaussian-mixture corpus (64 clusters), cosine, 4-bit, `nlist=128`, +sweeping `nprobe` against the flat scalar baseline (env knobs: `DIM`, `N`, `NQ`, `CLUSTERS`, `NLIST`): + +| config | recall@10 | QPS | speedup vs flat | +| ------- | --------- | ---- | --------------- | +| flat | 0.603 | 53 | 1.0× | +| ivf@1 | 0.387 | 1205 | 22.8× | +| ivf@4 | 0.602 | 852 | 16.1× | +| ivf@8 | 0.603 | 600 | 11.4× | +| ivf@128 | 0.603 | 60 | 1.1× | + +Recall is measured against the exact **float32** ground truth, so the 0.603 ceiling is the 4-bit +quantizer's own recall (the flat row) — IVF reaches that ceiling while probing 6% of the cells +(`nprobe=8`), and `nprobe = nlist` reproduces the flat scan exactly (the `searchSlots` oracle). The +speedup grows with corpus size: the probed-cell scan is O(n·nprobe/nlist) while flat is O(n). diff --git a/docs/guide.md b/docs/guide.md index 44e870e..6c90dce 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -12,6 +12,9 @@ An end-to-end tour of the two index classes, metrics, filtering, persistence, an - **`Collection

`** — the ergonomic, qdrant-style layer: payloads + a structured filter DSL on top of `IdMapIndex` (see [Collections](#collections-payloads--filters)). +All three layers accept the same scaling knobs: `calibrate` (TQ+), `fastscan` (4-bit SIMD), and +`ivf` (coarse-quantized sublinear search — see [IVF](#ivf-coarse-quantizer)). + ## Adding vectors ```ts @@ -105,13 +108,13 @@ try { } ``` -| Error | Sample codes | -| ------------------ | ------------------------------------------------------------------------------------------------------------------------- | -| `IndexError` | `INVALID_DIM`, `INVALID_BITS`, `INVALID_SEED`, `INVALID_VECTOR`, `INVALID_LENGTH`, `INVALID_INDEX`, `EMPTY`, `WRONG_KIND` | -| `IdMapError` | `DUPLICATE_ID`, `UNKNOWN_ID`, `COUNT_MISMATCH`, `INVALID_ID_TYPE`, `INVALID_VECTOR`, `EMPTY`, `WRONG_KIND` | -| `DeserializeError` | `BAD_MAGIC`, `BAD_VERSION`, `BAD_KIND`, `BAD_DIM`, `BAD_SEED`, `BAD_LENGTH`, `BAD_ID`, `TOO_SHORT` | -| `EncodeError` | `ZERO_VECTOR`, `INVALID_LENGTH`, `DEGENERATE` | -| `SearchError` | `INVALID_K`, `ZERO_QUERY`, `INVALID_MASK` | +| Error | Sample codes | +| ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `IndexError` | `INVALID_DIM`, `INVALID_BITS`, `INVALID_SEED`, `INVALID_VECTOR`, `INVALID_LENGTH`, `INVALID_INDEX`, `INVALID_NLIST`, `INVALID_NPROBE`, `EMPTY`, `WRONG_KIND` | +| `IdMapError` | `DUPLICATE_ID`, `UNKNOWN_ID`, `COUNT_MISMATCH`, `INVALID_ID_TYPE`, `INVALID_VECTOR`, `EMPTY`, `WRONG_KIND` | +| `DeserializeError` | `BAD_MAGIC`, `BAD_VERSION`, `BAD_KIND`, `BAD_DIM`, `BAD_SEED`, `BAD_LENGTH`, `BAD_ID`, `BAD_IVF`, `TOO_SHORT` | +| `EncodeError` | `ZERO_VECTOR`, `INVALID_LENGTH`, `DEGENERATE` | +| `SearchError` | `INVALID_K`, `ZERO_QUERY`, `INVALID_MASK`, `INVALID_SLOT` | ## Calibration (TQ+) @@ -136,6 +139,37 @@ distribution (e.g. 
anti-correlated with a tight calibration cluster) may not be `add` rejects it with `EncodeError` code `DEGENERATE`. If your data drifts that far, rebuild the index without `calibrate`. +## IVF (coarse quantizer) + +For large corpora the O(n) flat scan becomes the bottleneck. Enabling `ivf` partitions the corpus +into `nlist` k-means cells; each query ranks the cell centroids and scans only the `nprobe` nearest +cells: + +```ts +const idx = new TurboQuantIndex({ dim: 768, ivf: { nlist: 256 } }); +idx.add(corpus); // first add of ≥ nlist vectors trains + freezes the cells +idx.ivfActive; // → true +idx.search(q, 10); // probes ⌈nlist/8⌉ cells by default +idx.search(q, 10, { nprobe: 64 }); // recall/speed knob, per query +``` + +- **Training**: the cells are fit (seeded k-means++, spherical for cosine/dot, L2 for euclidean) + from the **first** non-empty add and frozen for the index's lifetime — exactly the calibration + contract. The hard minimum is `nlist` vectors (≥ ~32·nlist recommended); a smaller first batch + freezes the index flat forever. Choose `nlist ≈ √n` as a starting point. +- **Exactness**: the probed-cell scan uses the same exact scalar kernel as the flat path, so + `nprobe = nlist` reproduces the flat scan bit-for-bit; smaller `nprobe` trades recall for speed + (measured: ~11× QPS at the flat scan's recall with `nprobe = nlist/16` on clustered data). +- **Mutations**: `add`/`addOne` assign new vectors to their nearest cell; `swapRemove`/`remove` + keep the posting lists in lockstep (full parity with the flat index); `clear()` keeps the trained + cells. Serialization round-trips the whole structure (format v2). +- **Trade-offs**: while IVF is active the whole-database WASM/FastScan kernels are bypassed (a + cell-resident kernel is a future wave), and `calibrate`'s `DEGENERATE` caveat applies to cell + quality too: heavy data drift after training degrades the partition — rebuild to retrain. 
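
The probe step described in the bullets above (rank the cell centroids, keep the `nprobe` nearest, scan only their slots) might look roughly like this. It is a sketch under assumptions, not quantvec's implementation: `probeSlots` is a hypothetical name, centroids are ranked by plain dot product (which matches cosine only for unit-norm inputs), and a full sort stands in for the bounded top-k heap the library uses.

```typescript
// Hypothetical sketch: select the slots of the `nprobe` best-matching cells.
function probeSlots(
  centroids: Float32Array[], // one centroid per cell, all of length dim
  postings: number[][], // postings[cell] = slots assigned to that cell
  query: Float32Array,
  nprobe: number,
): Int32Array {
  const nlist = centroids.length;
  if (!Number.isInteger(nprobe) || nprobe < 1 || nprobe > nlist) {
    throw new RangeError(`nprobe must be an integer in [1, ${nlist}]`);
  }

  // Score every centroid against the query (higher dot product = nearer cell).
  const scored = centroids.map((c, cell) => {
    let dot = 0;
    for (let i = 0; i < c.length; i++) dot += c[i]! * query[i]!;
    return { cell, dot };
  });
  // Ties break on the lower cell index so the result is deterministic.
  scored.sort((a, b) => b.dot - a.dot || a.cell - b.cell);
  const chosen = scored.slice(0, nprobe);

  // Concatenate the chosen cells' posting lists into one candidate array.
  let total = 0;
  for (const { cell } of chosen) total += postings[cell]!.length;
  const slots = new Int32Array(total);
  let at = 0;
  for (const { cell } of chosen) {
    const posting = postings[cell]!;
    slots.set(posting, at);
    at += posting.length;
  }
  return slots;
}

// Usage: three cells over five slots; the query points mostly along the first centroid.
const demoCentroids = [
  new Float32Array([1, 0]),
  new Float32Array([0, 1]),
  new Float32Array([-1, 0]),
];
const demoPostings = [[0, 3], [1], [2, 4]];
const demoQuery = new Float32Array([0.9, 0.1]);
console.log(probeSlots(demoCentroids, demoPostings, demoQuery, 1)); // slots of the nearest cell only
console.log(probeSlots(demoCentroids, demoPostings, demoQuery, 3).length); // every slot when nprobe = nlist
```

With `nprobe = nlist` every posting list is selected, so the candidate set is the whole index; that is the property behind the "reproduces the flat scan" guarantee, as long as the downstream scan is the same exact kernel.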
+ +The same knobs flow through the other layers: `new IdMapIndex({ dim, ivf: { nlist } })` and +`createCollection({ ..., ivf: { nlist } })`, with `nprobe` accepted by their search options. + ## Collections (payloads + filters) `createCollection` is the highest-level API — store points with typed payloads and query with a diff --git a/docs/roadmap.md b/docs/roadmap.md index ae2bc27..34fc894 100644 --- a/docs/roadmap.md +++ b/docs/roadmap.md @@ -43,11 +43,22 @@ merges only when the gate is green (typecheck, lint, tests, coverage) and review `benchmarks/results/`. GloVe-200 exercises the dense Householder rotation (dim=200, non-power-of-two); dbpedia-OpenAI-100k exercises the FWHT path (dim=1536, power-of-two). +- **IVF / coarse quantizer** — opt-in (`ivf: { nlist, nprobe? }`) sublinear search for large + corpora: seeded k-means++ cells (spherical for cosine/dot, L2 for euclidean) trained from the + first ≥ nlist-vector batch and frozen (same contract as calibration), posting lists kept in + lockstep with swap-remove (full remove parity), per-query `nprobe` knob, and serialization in + format v2. The probed-cell scan reuses the exact scalar kernel, so `nprobe = nlist` reproduces + the flat scan bit-for-bit. Measured (20k × 768-d clustered, 4-bit): **11.4× QPS at the flat + scan's recall** (nprobe = 8/128), 22.8× at nprobe = 1. Full parity through `IdMapIndex` and + `Collection` (`ivf` config + `nprobe` search param). + ## Planned -- **IVF / coarse quantizer** for sublinear search on 10M+ corpora. +_(none — all planned waves have shipped)_ ## Non-goals (for now) -- An HNSW graph index — quantvec is deliberately a flat quantized index; IVF is the planned path to scale. -- A trained/learned codebook — the data-oblivious, zero-training property is the point. +- An HNSW graph index — quantvec is deliberately a flat quantized index; IVF (shipped) is the path to scale. +- A trained/learned codebook — the data-oblivious, zero-training property is the point. 
(The IVF + coarse quantizer trains only the cell _partition_, never the per-coordinate codebook — codes stay + data-oblivious.) diff --git a/docs/serialization.md b/docs/serialization.md index 56bf837..5a2b9e8 100644 --- a/docs/serialization.md +++ b/docs/serialization.md @@ -12,7 +12,7 @@ All multi-byte fields are little-endian. The header is 24 bytes: | Offset | Size | Field | Notes | | ------ | ---- | ------- | ---------------------------------------- | | 0 | 4 | magic | `"QVEC"` (`0x51 0x56 0x45 0x43`) | -| 4 | 1 | version | `1` | +| 4 | 1 | version | `2` | | 5 | 1 | kind | `0` = positional, `1` = id-keyed | | 6 | 1 | metric | `0` = dot, `1` = cosine, `2` = euclidean | | 7 | 1 | bits | `2`, `3`, or `4` | @@ -23,11 +23,15 @@ All multi-byte fields are little-endian. The header is 24 bytes: Body, immediately after the header: ``` -codes : ⌈n·dim·bits/8⌉ bytes (tightly bit-packed, LSB-first; dim is a multiple of 8 - so this is exact — no padding waste) -scales : n · f32 (per-vector RaBitQ scale) -norms : n · f32 (per-vector ‖v‖) -ids : n × tagged id (id-keyed only) +codes : ⌈n·dim·bits/8⌉ bytes (tightly bit-packed, LSB-first; dim is a multiple of 8 + so this is exact — no padding waste) +scales : n · f32 (per-vector RaBitQ scale) +norms : n · f32 (per-vector ‖v‖) +calibration : flag u8 ∈ {0,1} (1 → shift dim·f32 + scale dim·f32) +ivf : flag u8 ∈ {0,1} (1 → nlist u32 + nprobe u32 + + centroids nlist·dim·f32 + listForSlot n·u32; + posting lists are rebuilt from listForSlot on load) +ids : n × tagged id (id-keyed only) ``` Codes are stored at true 2/3/4 bits per coordinate, so the serialized index is 7.9–15.7× smaller than @@ -50,16 +54,21 @@ Each id is `tag (u8)` then payload: - `dim` is a positive multiple of 8 and `seed` is finite; - the declared body size fits within the buffer (so a crafted huge `n` can't trigger an out-of-bounds read or an out-of-memory allocation — it's rejected first); +- the calibration section: finite shift/scale with no zero scale (a 
divisor at query time); +- the ivf section: `nlist ∈ [2, 2^22]`, `nprobe ∈ [1, nlist]`, section bounds checked before any + allocation, finite centroids, and every `listForSlot` entry `< nlist`; - every id: bounds-checked length, **valid UTF-8** (fatal decode), **canonical** bigint decimal, and **no duplicate ids** (a collision would silently break the id↔slot bijection). Loading the wrong kind (e.g. positional bytes into `IdMapIndex.fromBytes`) throws `WRONG_KIND`; any structural problem throws a `DeserializeError` with a specific `.code` (`BAD_MAGIC`, `BAD_VERSION`, `BAD_KIND`, `BAD_METRIC`, `BAD_BITS`, `BAD_DIM`, `BAD_SEED`, `BAD_LENGTH`, -`BAD_ID`, `TOO_SHORT`). +`BAD_ID`, `BAD_CALIBRATION`, `BAD_IVF`, `TOO_SHORT`). ## Compatibility notes +- Version 2 (current) added the ivf presence byte and section. Per the pre-1.0 policy there are no + legacy readers: v1 buffers are rejected with `BAD_VERSION` — re-serialize with the current library. - The id type is **not** stored; pass it to `IdMapIndex.fromBytes` and ensure it matches. - The on-disk `codes` section is bit-packed; the index still holds one byte per code in memory. In-memory packing (and a SIMD scan over packed codes) is on the [roadmap](roadmap.md). diff --git a/docs/worklog/DECISIONS.md b/docs/worklog/DECISIONS.md index c13522b..3067270 100644 --- a/docs/worklog/DECISIONS.md +++ b/docs/worklog/DECISIONS.md @@ -2,6 +2,25 @@ Concise record of locked decisions and their rationale. Newest first. +## D-016 · IVF as an option on TurboQuantIndex (not a class); format v2; train-on-first-batch + +The IVF coarse quantizer ships as `ivf: { nlist, nprobe? }` on `TurboQuantIndexOptions`, NOT a +separate index class: `IdMapIndex` wraps `TurboQuantIndex` and `Collection` wraps `IdMapIndex`, so a +constructor option gives all three layers full parity (config + per-query `nprobe` + remove + ser/de) +with two passthrough lines each. 
State lives in `index/coarse.ts` (`CoarseQuantizer`: centroids + +posting lists with slot→list/slot→pos arrays for O(1) swap-remove patching) over `core/kmeans.ts` +(seeded k-means++/Lloyd; spherical for cosine/dot, L2 for euclidean; domain-separated RNG stream from +the index seed). Training mirrors calibration: fit-and-freeze from the first non-empty batch when it +has ≥ `nlist` vectors (the hard k-means floor — predictable from the user's own config; quality +guidance ≥ ~32·nlist lives in docs), else frozen flat forever. The probed-cell scan is the exact +scalar kernel (`searchSlots` in `core/search.ts`, sharing `prepareScan` validation with `searchFlat`) +→ `nprobe = nlist` ≡ flat scan bit-for-bit (the IVF analog of the WASM≡scalar oracle); the +whole-database WASM/FastScan kernels are bypassed while IVF is active (cell-resident kernel = future +wave). Serialization: format `VERSION` 1 → 2 with an always-written ivf presence byte (mirror of the +calibration byte) + `nlist/nprobe/centroids/listForSlot` (postings rebuilt on load); v2-only readers +per D-010 — a v1 reader rejects v2 cleanly with `BAD_VERSION` instead of misparsing. Measured +(20k × 768-d clustered, 4-bit): 11.4× QPS at the flat scan's recall (nprobe = nlist/16). 
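The O(1) swap-remove patching described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the `CoarseQuantizer` source: the `PostingLists` class and `invariantHolds` helper are hypothetical, and only the slot→list / slot→pos bookkeeping is modeled (no centroids, no assignment).

```ts
// Hypothetical sketch (NOT the quantvec CoarseQuantizer): posting lists with
// slot->list and slot->pos arrays so a swap-remove patches in O(1).
class PostingLists {
  readonly postings: number[][];
  readonly listForSlot: number[] = [];
  readonly posForSlot: number[] = [];

  constructor(nlist: number) {
    this.postings = Array.from({ length: nlist }, () => [] as number[]);
  }

  get size(): number {
    return this.listForSlot.length;
  }

  /** Append a new slot (the next row index) to posting list `list`. */
  addSlot(list: number): number {
    const slot = this.listForSlot.length;
    this.listForSlot.push(list);
    this.posForSlot.push(this.postings[list]!.length);
    this.postings[list]!.push(slot);
    return slot;
  }

  /** Mirror of the index's swap-remove: row `last` moves into slot `i`. */
  swapRemove(i: number): void {
    const last = this.listForSlot.length - 1;
    // Step A: swap-pop slot i out of its own posting list.
    const listI = this.postings[this.listForSlot[i]!]!;
    const posI = this.posForSlot[i]!;
    const tail = listI.pop()!;
    if (tail !== i) {
      listI[posI] = tail;
      this.posForSlot[tail] = posI;
    }
    // Step B: renumber `last` to `i`. Its position is read AFTER step A,
    // because step A may have just moved `last` (the same-list-tail case).
    if (i !== last) {
      const listLast = this.listForSlot[last]!;
      const posLast = this.posForSlot[last]!;
      this.postings[listLast]![posLast] = i;
      this.listForSlot[i] = listLast;
      this.posForSlot[i] = posLast;
    }
    this.listForSlot.pop();
    this.posForSlot.pop();
  }
}

/** The invariant: postings[listForSlot[s]][posForSlot[s]] === s for every slot. */
function invariantHolds(p: PostingLists): boolean {
  for (let s = 0; s < p.size; s++) {
    if (p.postings[p.listForSlot[s]!]![p.posForSlot[s]!] !== s) return false;
  }
  return true;
}

const lists = new PostingLists(2);
for (const list of [0, 1, 0, 1, 0]) lists.addSlot(list);
lists.swapRemove(1); // slot 4 is renumbered to slot 1
console.log(invariantHolds(lists), lists.size); // true 4
```

Reading the moved slot's position only after step A is what makes the same-list-tail case fall out without a special branch.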
+ ## D-015 · WASM kernel: exact f64 + resident codes, not approximate FastScan (first) The WASM acceleration (`assembly/index.ts` + `src/wasm/kernel.ts`) ships as an EXACT kernel: codes diff --git a/package-lock.json b/package-lock.json index 4b47ca3..0d8c0e9 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "quantvec", - "version": "0.0.2", + "version": "0.0.3", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "quantvec", - "version": "0.0.2", + "version": "0.0.3", "license": "Apache-2.0", "devDependencies": { "@eslint/js": "^10.0.1", diff --git a/package.json b/package.json index 9b3a07e..3048ce9 100644 --- a/package.json +++ b/package.json @@ -71,6 +71,7 @@ "test:coverage": "vitest run --coverage --test-timeout=30000", "bench": "vitest bench --run", "bench:flat": "npx tsx benchmarks/flat.ts", + "bench:ivf": "npx tsx benchmarks/ivf.ts", "bench:real": "node benchmarks/download-siftsmall.mjs && npx tsx benchmarks/real.ts", "bench:glove": "node benchmarks/download-glove.mjs && npx tsx benchmarks/glove.ts", "bench:openai": "node benchmarks/download-openai.mjs && npx tsx benchmarks/openai.ts", diff --git a/src/core/kmeans.test.ts b/src/core/kmeans.test.ts new file mode 100644 index 0000000..0cf2afc --- /dev/null +++ b/src/core/kmeans.test.ts @@ -0,0 +1,221 @@ +import { describe, expect, it } from 'vitest'; +import { kmeans, nearestCentroid, KMeansError } from './kmeans'; +import type { KMeansResult } from './kmeans'; +import { createRng } from './rng'; + +/** m·dim row-major blob data: `centers.length` tight gaussian blobs. 
*/ +function blobs( + centers: number[][], + perBlob: number, + dim: number, + seed: number, + noise = 0.05, +): { data: Float32Array; m: number; blobOf: Int32Array } { + const rng = createRng(seed); + const m = centers.length * perBlob; + const data = new Float32Array(m * dim); + const blobOf = new Int32Array(m); + let r = 0; + for (let b = 0; b < centers.length; b++) { + for (let j = 0; j < perBlob; j++, r++) { + blobOf[r] = b; + for (let i = 0; i < dim; i++) { + data[r * dim + i] = centers[b]![i]! + rng.nextGaussian() * noise; + } + } + } + return { data, m, blobOf }; +} + +function run( + data: Float32Array, + m: number, + k: number, + dim: number, + seed: number, + spherical = false, +): KMeansResult { + return kmeans(data, m, { k, dim, rng: createRng(seed), spherical }); +} + +describe('kmeans — validation', () => { + const data = new Float32Array(8 * 4); + + it('rejects a bad dim', () => { + let err: unknown; + try { + kmeans(data, 8, { k: 2, dim: 0, rng: createRng(1), spherical: false }); + } catch (e) { + err = e; + } + expect(err).toBeInstanceOf(KMeansError); + expect((err as KMeansError).code).toBe('INVALID_DIM'); + }); + + it('rejects a length that is not m·dim', () => { + let err: unknown; + try { + kmeans(data, 7, { k: 2, dim: 4, rng: createRng(1), spherical: false }); + } catch (e) { + err = e; + } + expect((err as KMeansError).code).toBe('INVALID_LENGTH'); + }); + + it.each([1, 9, 2.5])('rejects k = %s outside [2, m] (m = 8)', (k) => { + let err: unknown; + try { + kmeans(data, 8, { k, dim: 4, rng: createRng(1), spherical: false }); + } catch (e) { + err = e; + } + expect((err as KMeansError).code).toBe('INVALID_K'); + }); +}); + +describe('kmeans — clustering behavior', () => { + it('recovers two well-separated blobs exactly', () => { + const { data, m, blobOf } = blobs( + [ + [10, 0, 0, 0], + [-10, 0, 0, 0], + ], + 50, + 4, + 11, + ); + const { centroids, assignments } = run(data, m, 2, 4, 7); + // Every row in the same blob lands in the same 
cluster, blobs in different ones. + const clusterOfBlob = [assignments[0]!, assignments[50]!]; + expect(clusterOfBlob[0]).not.toBe(clusterOfBlob[1]); + for (let r = 0; r < m; r++) expect(assignments[r]).toBe(clusterOfBlob[blobOf[r]!]!); + // Centroids sit on the blob centers. + for (const c of clusterOfBlob) { + expect(Math.abs(Math.abs(centroids[c! * 4]!) - 10)).toBeLessThan(0.5); + } + }); + + it('is deterministic: same data + seed yields identical centroids and assignments', () => { + const { data, m } = blobs( + [ + [5, 5], + [-5, 5], + [0, -5], + ], + 40, + 2, + 21, + ); + const a = run(data, m, 3, 2, 99); + const b = run(data, m, 3, 2, 99); + expect(Array.from(a.centroids)).toEqual(Array.from(b.centroids)); + expect(Array.from(a.assignments)).toEqual(Array.from(b.assignments)); + expect(a.iterations).toBe(b.iterations); + }); + + it('repairs empty clusters: k=3 over 2 tight blobs still yields 3 non-empty clusters', () => { + const { data, m } = blobs( + [ + [10, 0], + [-10, 0], + ], + 30, + 2, + 31, + 0.5, + ); + const { centroids, assignments } = run(data, m, 3, 2, 13); + const counts = [0, 0, 0]; + for (let r = 0; r < m; r++) counts[assignments[r]!]!++; + expect(counts.every((c) => c > 0)).toBe(true); + for (const x of centroids) expect(Number.isFinite(x)).toBe(true); + }); + + it('spherical mode returns unit-norm centroids and clusters directions', () => { + // Two antipodal direction blobs on the unit circle (pre-normalized rows). + const raw = blobs( + [ + [1, 0, 0, 0], + [-1, 0, 0, 0], + ], + 40, + 4, + 41, + 0.02, + ); + for (let r = 0; r < raw.m; r++) { + let n = 0; + for (let i = 0; i < 4; i++) n += raw.data[r * 4 + i]! ** 2; + const inv = 1 / Math.sqrt(n); + for (let i = 0; i < 4; i++) raw.data[r * 4 + i] = raw.data[r * 4 + i]! * inv; + } + const { centroids, assignments } = run(raw.data, raw.m, 2, 4, 5, true); + for (let c = 0; c < 2; c++) { + let n = 0; + for (let i = 0; i < 4; i++) n += centroids[c * 4 + i]! 
** 2; + expect(Math.sqrt(n)).toBeCloseTo(1, 6); + } + expect(assignments[0]).not.toBe(assignments[40]); + }); + + it('terminates within the iteration cap and reports iterations', () => { + const { data, m } = blobs( + [ + [3, 0], + [-3, 0], + ], + 25, + 2, + 51, + ); + const res = kmeans(data, m, { + k: 2, + dim: 2, + rng: createRng(3), + spherical: false, + maxIterations: 4, + }); + expect(res.iterations).toBeGreaterThanOrEqual(1); + expect(res.iterations).toBeLessThanOrEqual(4); + }); + + it('handles k === m (every point its own centroid)', () => { + const data = Float32Array.from([0, 0, 5, 0, 0, 5, 5, 5]); + const { assignments } = run(data, 4, 4, 2, 17); + expect(new Set(Array.from(assignments)).size).toBe(4); + }); + + it('handles duplicate rows (zero D² mass in seeding) without NaN', () => { + const data = new Float32Array(6 * 2).fill(1); // 6 identical points + const { centroids } = run(data, 6, 2, 2, 23); + for (const x of centroids) expect(Number.isFinite(x)).toBe(true); + }); +}); + +describe('nearestCentroid', () => { + it('matches brute force for both affinities', () => { + const rng = createRng(77); + const dim = 8; + const k = 5; + const centroids = new Float32Array(k * dim).map(() => rng.nextGaussian()); + const v = new Float32Array(dim).map(() => rng.nextGaussian()); + for (const spherical of [false, true]) { + let best = 0; + let bestKey = -Infinity; + for (let c = 0; c < k; c++) { + let dot = 0; + let d2 = 0; + for (let i = 0; i < dim; i++) { + dot += v[i]! * centroids[c * dim + i]!; + d2 += (v[i]! - centroids[c * dim + i]!) ** 2; + } + const key = spherical ? dot : -d2; + if (key > bestKey) { + bestKey = key; + best = c; + } + } + expect(nearestCentroid(v, 0, centroids, k, dim, spherical)).toBe(best); + } + }); +}); diff --git a/src/core/kmeans.ts b/src/core/kmeans.ts new file mode 100644 index 0000000..75996b8 --- /dev/null +++ b/src/core/kmeans.ts @@ -0,0 +1,282 @@ +// Seeded k-means for quantvec's IVF coarse quantizer. 
+// +// Why this exists: IVF (inverted-file) search partitions the corpus into `k` +// coarse cells and probes only the cells nearest the query — sublinear scan cost +// on large corpora. The cells come from plain k-means: k-means++ seeding (Arthur +// & Vassilvitskii, "k-means++: The Advantages of Careful Seeding", SODA 2007) +// followed by Lloyd iterations (Lloyd, "Least Squares Quantization in PCM", +// IEEE Trans. IT 1982). Everything is driven by the caller's seeded RNG (see +// ./rng), so a given (data, seed) pair yields bit-identical centroids on every +// runtime — same determinism contract as the rotation. +// +// Two affinity modes: +// spherical=false — plain L2: assign by argmin ‖x − c‖², centroids are means. +// spherical=true — directions (cosine/dot corpora): assign by argmax ⟨x, c⟩ +// over unit centroids; after each mean step centroids are renormalized to +// unit length (a zero-norm mean keeps its previous direction). Callers +// pre-normalize the rows. + +import type { Rng } from './rng'; + +/** Discriminated, code-tagged error for the k-means module. */ +export class KMeansError extends Error { + readonly code: 'INVALID_K' | 'INVALID_DIM' | 'INVALID_LENGTH'; + constructor(code: KMeansError['code'], message: string) { + super(message); + this.name = 'KMeansError'; + this.code = code; + } +} + +/** Options for {@link kmeans}. */ +export interface KMeansOptions { + /** Number of centroids; integer in [2, m]. */ + k: number; + /** Coordinate dimension of each row. */ + dim: number; + /** Caller-seeded RNG — init is the only stochastic step. */ + rng: Rng; + /** Spherical mode (unit-direction rows, dot-product affinity). */ + spherical: boolean; + /** Lloyd iteration cap (default 25). */ + maxIterations?: number; +} + +/** Result of {@link kmeans}. */ +export interface KMeansResult { + /** Row-major k·dim centroids (unit-norm rows when spherical). */ + centroids: Float32Array; + /** Final assignment per input row, length m. 
*/ + assignments: Int32Array; + /** Lloyd iterations actually run (≤ maxIterations). */ + iterations: number; +} + +/** Squared L2 distance between row `r` of `data` and row `c` of `centroids`. */ +function distSq( + data: Float32Array, + r: number, + centroids: Float32Array, + c: number, + dim: number, +): number { + let s = 0; + const rBase = r * dim; + const cBase = c * dim; + for (let i = 0; i < dim; i++) { + const d = data[rBase + i]! - centroids[cBase + i]!; + s += d * d; + } + return s; +} + +/** Dot product between row `r` of `data` and row `c` of `centroids`. */ +function rowDot( + data: Float32Array, + r: number, + centroids: Float32Array, + c: number, + dim: number, +): number { + let s = 0; + const rBase = r * dim; + const cBase = c * dim; + for (let i = 0; i < dim; i++) s += data[rBase + i]! * centroids[cBase + i]!; + return s; +} + +/** + * Nearest centroid of row `r` under the chosen affinity. Spherical ranks by + * dot product over unit centroids (argmax ⟨x, c⟩ — scale-invariant in x, so + * rows need not be re-normalized per call); L2 ranks by squared distance. + * First-best wins ties, keeping the result deterministic. + */ +export function nearestCentroid( + data: Float32Array, + r: number, + centroids: Float32Array, + k: number, + dim: number, + spherical: boolean, +): number { + let best = 0; + if (spherical) { + let bestDot = rowDot(data, r, centroids, 0, dim); + for (let c = 1; c < k; c++) { + const d = rowDot(data, r, centroids, c, dim); + if (d > bestDot) { + bestDot = d; + best = c; + } + } + } else { + let bestD = distSq(data, r, centroids, 0, dim); + for (let c = 1; c < k; c++) { + const d = distSq(data, r, centroids, c, dim); + if (d < bestD) { + bestD = d; + best = c; + } + } + } + return best; +} + +/** Normalize centroid row `c` to unit length; a zero-norm row is left unchanged. 
*/ +function renormalizeRow(centroids: Float32Array, c: number, dim: number): void { + const base = c * dim; + let normSq = 0; + for (let i = 0; i < dim; i++) normSq += centroids[base + i]! * centroids[base + i]!; + if (normSq === 0) return; // keep the previous direction + const inv = 1 / Math.sqrt(normSq); + for (let i = 0; i < dim; i++) centroids[base + i] = centroids[base + i]! * inv; +} + +/** + * k-means++ seeding: the first centroid is drawn uniformly; each subsequent one + * is drawn with probability ∝ D²(x), the squared L2 distance to the nearest + * already-chosen centroid (one cumulative-sum pass + one uniform draw each). + * When every remaining D² is 0 (k duplicate rows), fall back to a uniform draw — + * a reachable degenerate input, so guarded and tested. + */ +function seedCentroids( + data: Float32Array, + m: number, + k: number, + dim: number, + rng: Rng, +): Float32Array { + const centroids = new Float32Array(k * dim); + const d2 = new Float64Array(m); + + const first = Math.min(m - 1, Math.floor(rng.nextFloat() * m)); + centroids.set(data.subarray(first * dim, first * dim + dim), 0); + for (let r = 0; r < m; r++) d2[r] = distSq(data, r, centroids, 0, dim); + + for (let c = 1; c < k; c++) { + let total = 0; + for (let r = 0; r < m; r++) total += d2[r]!; + let pick: number; + if (total > 0) { + // Inverse-CDF draw over the D² weights. + const target = rng.nextFloat() * total; + let acc = 0; + pick = m - 1; // float-rounding fallback: the last row + for (let r = 0; r < m; r++) { + acc += d2[r]!; + if (acc > target) { + pick = r; + break; + } + } + } else { + pick = Math.min(m - 1, Math.floor(rng.nextFloat() * m)); + } + centroids.set(data.subarray(pick * dim, pick * dim + dim), c * dim); + // Fold the new centroid into the D² table. + for (let r = 0; r < m; r++) { + const d = distSq(data, r, centroids, c, dim); + if (d < d2[r]!) 
d2[r] = d; + } + } + return centroids; +} + +/** + * Run seeded k-means++ + Lloyd over `m` row-major rows of `data` (length m·dim). + * + * Deterministic for a fixed (data, rng seed): k-means++ seeding is the only + * stochastic step, assignment ties break first-best, and the loop stops the + * first iteration that changes zero assignments (exact, float-tolerance-free) + * or at `maxIterations`. Empty clusters are repaired each round by re-seeding + * from the row currently farthest from its centroid (first-max wins). + * + * @throws {KMeansError} `'INVALID_DIM'` on a non-positive-integer dim; + * `'INVALID_LENGTH'` if `data.length !== m·dim`; `'INVALID_K'` unless + * `2 ≤ k ≤ m` (integer). + */ +export function kmeans(data: Float32Array, m: number, opts: KMeansOptions): KMeansResult { + const { k, dim, rng, spherical } = opts; + const maxIterations = opts.maxIterations ?? 25; + + if (!Number.isInteger(dim) || dim <= 0) { + throw new KMeansError('INVALID_DIM', `dim must be a positive integer, got ${dim}`); + } + if (!Number.isInteger(m) || m <= 0 || data.length !== m * dim) { + throw new KMeansError( + 'INVALID_LENGTH', + `data length ${data.length} must equal m·dim = ${m * dim}`, + ); + } + if (!Number.isInteger(k) || k < 2 || k > m) { + throw new KMeansError('INVALID_K', `k must be an integer in [2, m=${m}], got ${k}`); + } + + const centroids = seedCentroids(data, m, k, dim, rng); + if (spherical) for (let c = 0; c < k; c++) renormalizeRow(centroids, c, dim); + + const assignments = new Int32Array(m).fill(-1); + const counts = new Int32Array(k); + const sums = new Float64Array(k * dim); + + let iterations = 0; + while (iterations < maxIterations) { + iterations++; + + // ── Assign ──────────────────────────────────────────────────────────── + let changed = 0; + for (let r = 0; r < m; r++) { + const c = nearestCentroid(data, r, centroids, k, dim, spherical); + if (c !== assignments[r]) { + assignments[r] = c; + changed++; + } + } + + // ── Update means (f64 
accumulation) ─────────────────────────────────── + counts.fill(0); + sums.fill(0); + for (let r = 0; r < m; r++) { + const c = assignments[r]!; + counts[c]!++; + const rBase = r * dim; + const cBase = c * dim; + for (let i = 0; i < dim; i++) sums[cBase + i] += data[rBase + i]!; + } + for (let c = 0; c < k; c++) { + if (counts[c] === 0) continue; // repaired below; keep the previous centroid + const inv = 1 / counts[c]!; + const base = c * dim; + for (let i = 0; i < dim; i++) centroids[base + i] = sums[base + i]! * inv; + if (spherical) renormalizeRow(centroids, c, dim); + } + + // ── Empty-cluster repair ────────────────────────────────────────────── + // Re-seed each empty centroid from the row farthest (L2) from its current + // centroid (first-max wins), claim that row, and let the next assign pass + // settle the rest. Bounded: each repair fills one empty cluster. + for (let c = 0; c < k; c++) { + if (counts[c]! > 0) continue; + let farRow = 0; + let farD = -1; + for (let r = 0; r < m; r++) { + if (counts[assignments[r]!]! 
<= 1) continue; // don't orphan a singleton + const d = distSq(data, r, centroids, assignments[r]!, dim); + if (d > farD) { + farD = d; + farRow = r; + } + } + counts[assignments[farRow]!]!--; + assignments[farRow] = c; + counts[c] = 1; + centroids.set(data.subarray(farRow * dim, farRow * dim + dim), c * dim); + if (spherical) renormalizeRow(centroids, c, dim); + changed++; + } + + if (changed === 0) break; + } + + return { centroids, assignments, iterations }; +} diff --git a/src/core/search.test.ts b/src/core/search.test.ts index 2f5656b..aa2ccc6 100644 --- a/src/core/search.test.ts +++ b/src/core/search.test.ts @@ -1,5 +1,5 @@ import { describe, expect, it } from 'vitest'; -import { buildQueryLut, searchFlat, SearchError } from './search'; +import { buildQueryLut, searchFlat, searchSlots, SearchError } from './search'; import type { EncodedDb } from './search'; import { getCodebook } from './codebook'; import type { Bits } from './codebook'; @@ -452,3 +452,76 @@ describe('searchFlat — TQ+ calibration', () => { expect(res.scores.every((s) => Number.isFinite(s))).toBe(true); }); }); + +// ── searchSlots: the IVF probed-list subset scan ──────────────────────────────── + +describe('searchSlots', () => { + const dim = 64; + const n = 50; + const rng = createRng(135); + const vectors = randomVectors(n, dim, rng); + const db = buildDb(vectors, dim, 4); + const allSlots = Int32Array.from({ length: n }, (_, j) => j); + + it('over all slots equals searchFlat exactly (indices and scores)', () => { + for (const metric of ['dot', 'cosine', 'euclidean'] as const) { + const flat = searchFlat(db, vectors[2]!, 10, { metric }); + const sub = searchSlots(db, vectors[2]!, 10, allSlots, { metric }); + expect(Array.from(sub.indices)).toEqual(Array.from(flat.indices)); + expect(Array.from(sub.scores)).toEqual(Array.from(flat.scores)); + } + }); + + it('only returns members of the given subset', () => { + const subset = Int32Array.from([1, 5, 9, 13, 17, 21]); + const res = 
searchSlots(db, vectors[0]!, 4, subset, { metric: 'cosine' }); + const allowed = new Set(Array.from(subset)); + expect(res.indices.length).toBe(4); + for (const j of res.indices) expect(allowed.has(j)).toBe(true); + }); + + it('honors the full-length mask within the subset', () => { + const subset = Int32Array.from([0, 1, 2, 3]); + const mask = new Uint8Array(n).fill(1); + mask[1] = 0; + const res = searchSlots(db, vectors[1]!, 4, subset, { metric: 'dot', mask }); + expect(Array.from(res.indices)).not.toContain(1); + expect(res.indices.length).toBe(3); + }); + + it('empty slots yields an empty result; k > candidates yields a short result', () => { + const empty = searchSlots(db, vectors[0]!, 5, new Int32Array(0), { metric: 'dot' }); + expect(empty.indices.length).toBe(0); + const short = searchSlots(db, vectors[0]!, 5, Int32Array.from([3, 4]), { metric: 'dot' }); + expect(short.indices.length).toBe(2); + }); + + it('rejects an out-of-range slot with INVALID_SLOT', () => { + for (const bad of [-1, n]) { + let err: unknown; + try { + searchSlots(db, vectors[0]!, 2, Int32Array.from([0, bad]), { metric: 'dot' }); + } catch (e) { + err = e; + } + expect(err).toBeInstanceOf(SearchError); + expect((err as SearchError).code).toBe('INVALID_SLOT'); + } + }); + + it('shares searchFlat validation: bad mask length and zero query throw identically', () => { + let err: unknown; + try { + searchSlots(db, vectors[0]!, 2, allSlots, { metric: 'dot', mask: new Uint8Array(3) }); + } catch (e) { + err = e; + } + expect((err as SearchError).code).toBe('INVALID_MASK'); + try { + searchSlots(db, new Float32Array(dim), 2, allSlots, { metric: 'dot' }); + } catch (e) { + err = e; + } + expect((err as SearchError).code).toBe('ZERO_QUERY'); + }); +}); diff --git a/src/core/search.ts b/src/core/search.ts index 6da2782..98e69b6 100644 --- a/src/core/search.ts +++ b/src/core/search.ts @@ -47,7 +47,8 @@ export class SearchError extends Error { | 'INVALID_LENGTH' | 'MISMATCH' | 'ZERO_QUERY' - | 
'INVALID_MASK'; + | 'INVALID_MASK' + | 'INVALID_SLOT'; constructor(code: SearchError['code'], message: string) { super(message); this.name = 'SearchError'; @@ -170,13 +171,60 @@ export function buildQueryLut( * * @throws {SearchError} on any failed precondition above. */ -export function searchFlat( +/** + * Validate a query vector against `dim` — length, per-element finiteness, and a + * non-zero norm — throwing the same typed errors for the same bad inputs on every + * search path. Used by {@link searchFlat}/{@link searchSlots} (via the shared scan + * preamble) and by the index's IVF branch *before* centroid probing, so a malformed + * query never reaches the probe arithmetic. Returns the query norms. + * + * @throws {SearchError} `'INVALID_LENGTH'` on a wrong-length or non-finite query; + * `'ZERO_QUERY'` on a zero query (no direction). + */ +export function validateQuery(query: Float32Array, dim: number): QueryNorms { + if (query.length !== dim) { + throw new SearchError('INVALID_LENGTH', `query length ${query.length} != dim ${dim}`); + } + let qNormSq = 0; + for (let i = 0; i < dim; i++) { + const x = query[i]!; + if (!Number.isFinite(x)) { + throw new SearchError('INVALID_LENGTH', `query[${i}] must be finite, got ${x}`); + } + qNormSq += x * x; + } + if (qNormSq === 0) { + throw new SearchError('ZERO_QUERY', 'cannot search with a zero query (no direction)'); + } + return { qNorm: Math.sqrt(qNormSq), qNormSq }; +} + +/** Per-query state shared by {@link searchFlat} and {@link searchSlots}. */ +interface PreparedScan { + /** The per-query nibble LUT (dim·levels). */ + lut: Float32Array; + /** Per-query calibration bias ⟨q_rot, shift⟩ (0 when un-calibrated). */ + biasQ: number; + /** Query norms for ./metrics. */ + norms2: QueryNorms; + /** 2^bits. 
*/ + levels: number; +} + +/** + * Shared scan preamble: validate every boundary (k, shapes, mask length, query + * finiteness/zero), rotate the query once, apply the calibration dual, and build + * the per-query LUT. Both scan entry points run identical validation, so they + * throw identical typed errors for identical bad inputs. + * + * @throws {SearchError} on any failed precondition (see {@link searchFlat}). + */ +function prepareScan( db: EncodedDb, query: Float32Array, k: number, opts: SearchOptions, - computeScores?: (lut: Float32Array, out: Float64Array) => void, -): SearchResult { +): PreparedScan { const { n, dim, bits, codes, scales, norms, centroids, rotation } = db; const levels = 1 << bits; @@ -187,9 +235,7 @@ export function searchFlat( if (rotation.dim !== dim) { throw new SearchError('MISMATCH', `rotation.dim ${rotation.dim} != dim ${dim}`); } - if (query.length !== dim) { - throw new SearchError('INVALID_LENGTH', `query length ${query.length} != dim ${dim}`); - } + const norms2 = validateQuery(query, dim); if (codes.length !== n * dim) { throw new SearchError('MISMATCH', `codes length ${codes.length} != n·dim ${n * dim}`); } @@ -207,21 +253,6 @@ export function searchFlat( throw new SearchError('INVALID_MASK', `mask length ${mask.length} != n ${n}`); } - // ── Query norms (also rejects a non-finite / zero query) ───────────────── - let qNormSq = 0; - for (let i = 0; i < dim; i++) { - const x = query[i]!; - if (!Number.isFinite(x)) { - throw new SearchError('INVALID_LENGTH', `query[${i}] must be finite, got ${x}`); - } - qNormSq += x * x; - } - if (qNormSq === 0) { - throw new SearchError('ZERO_QUERY', 'cannot search with a zero query (no direction)'); - } - const qNorm = Math.sqrt(qNormSq); - const norms2: QueryNorms = { qNorm, qNormSq }; - // ── Rotate the query once, then build the shared LUT ───────────────────── // With TQ+ calibration the codes hold calibrated coordinates, so we score against // q_calib = q_rot / scale and subtract the 
per-query bias ⟨q_rot, shift⟩ from every @@ -242,7 +273,19 @@ export function searchFlat( } const lut = buildQueryLut(lutQuery, centroids, dim, levels); - const { metric } = opts; + return { lut, biasQ, norms2, levels }; +} + +export function searchFlat( + db: EncodedDb, + query: Float32Array, + k: number, + opts: SearchOptions, + computeScores?: (lut: Float32Array, out: Float64Array) => void, +): SearchResult { + const { n, dim, codes, scales, norms } = db; + const { lut, biasQ, norms2, levels } = prepareScan(db, query, k, opts); + const { mask, metric } = opts; const top = new TopK(k); // ── The flat scan ──────────────────────────────────────────────────────── @@ -277,3 +320,49 @@ export function searchFlat( for (let i = 0; i < indices.length; i++) scores[i] = mapKeyToValue(metric, keys[i]!); return { indices, scores }; } + +/** + * Subset scan: score ONLY the given `slots` (database row indices) — the IVF + * probed-posting-list scan. Identical query preparation, validation, and error + * semantics as {@link searchFlat} (same typed errors for the same bad inputs), + * plus `'INVALID_SLOT'` for a slot outside [0, n). `opts.mask`, when given, is + * the full n-length allowlist indexed by slot — a probed slot the mask excludes + * is skipped, exactly like the flat scan. + * + * Returns up to k best, best-first; fewer (possibly zero) when the slots/mask + * yield fewer candidates. A slot listed twice would be scored twice (the heap + * would then hold duplicates) — callers pass disjoint posting lists, and the + * IVF bookkeeping guarantees disjointness, so this is not guarded. + * + * @throws {SearchError} on any failed precondition above. 
+ */ +export function searchSlots( + db: EncodedDb, + query: Float32Array, + k: number, + slots: Int32Array, + opts: SearchOptions, +): SearchResult { + const { n, dim, codes, scales, norms } = db; + const { lut, biasQ, norms2, levels } = prepareScan(db, query, k, opts); + const { mask, metric } = opts; + + const top = new TopK(k); + for (let t = 0; t < slots.length; t++) { + const j = slots[t]!; + if (!Number.isInteger(j) || j < 0 || j >= n) { + throw new SearchError('INVALID_SLOT', `slot ${j} out of range [0, ${n})`); + } + if (mask !== undefined && !mask[j]) continue; + const base = j * dim; + let s = 0; + for (let i = 0; i < dim; i++) s += lut[i * levels + codes[base + i]!]!; + const { rankKey } = scoreMetric(metric, s - biasQ, scales[j]!, norms[j]!, norms2); + top.add(rankKey, j); + } + + const { indices, scores: keys } = top.result(); + const scores = new Float32Array(indices.length); + for (let i = 0; i < indices.length; i++) scores[i] = mapKeyToValue(metric, keys[i]!); + return { indices, scores }; +} diff --git a/src/ergonomic/collection.test.ts b/src/ergonomic/collection.test.ts index da134b2..3599d5a 100644 --- a/src/ergonomic/collection.test.ts +++ b/src/ergonomic/collection.test.ts @@ -163,3 +163,46 @@ describe('Collection — mutation', () => { expect(c.get(2)).toEqual({ tag: 'blog', year: 2023 }); }); }); + +describe('Collection — IVF config passthrough', () => { + function ivfCollection(): Collection { + const c = createCollection({ + vectors: { size: 8, distance: 'cosine' }, + quantization: { bits: 4 }, + ivf: { nlist: 2, nprobe: 1 }, + }); + // Two clusters along orthogonal directions; 4 points ≥ nlist trains the quantizer. 
+ c.upsert([ + { id: 1, vector: V[0]!, payload: { tag: 'a', year: 2020 } }, + { id: 2, vector: V[1]!, payload: { tag: 'b', year: 2021 } }, + { id: 3, vector: V[2]!, payload: { tag: 'a', year: 2022 } }, + { id: 4, vector: V[3]!, payload: { tag: 'b', year: 2023 } }, + ]); + return c; + } + + it('first upsert trains the coarse quantizer; search works with nprobe override', () => { + const c = ivfCollection(); + const hits = c.search(V[2]!, { limit: 2, nprobe: 2 }); + expect(hits[0]!.id).toBe(3); + expect(hits[0]!.payload).toEqual({ tag: 'a', year: 2022 }); + }); + + it('filters and deletes compose with IVF', () => { + const c = ivfCollection(); + const hits = c.search(V[0]!, { + limit: 4, + nprobe: 2, + filter: { must: [{ key: 'tag', match: { value: 'a' } }] }, + }); + expect(hits.every((h) => h.payload!.tag === 'a')).toBe(true); + c.delete([1, 3]); + expect(c.size).toBe(2); + expect( + c + .search(V[1]!, { limit: 4, nprobe: 2 }) + .map((h) => h.id) + .sort(), + ).toEqual([2, 4]); + }); +}); diff --git a/src/ergonomic/collection.ts b/src/ergonomic/collection.ts index ea1aaeb..0e40129 100644 --- a/src/ergonomic/collection.ts +++ b/src/ergonomic/collection.ts @@ -31,6 +31,7 @@ export class Collection

{ }; if (config.seed !== undefined) opts.seed = config.seed; if (config.calibrate !== undefined) opts.calibrate = config.calibrate; + if (config.ivf !== undefined) opts.ivf = config.ivf; this.#index = new IdMapIndex(opts); this.#payloads = new Map(); } @@ -101,6 +102,7 @@ export class Collection

{ const withPayload = params.withPayload !== false; const opts: IdMapSearchOptions = {}; + if (params.nprobe !== undefined) opts.nprobe = params.nprobe; if (params.filter !== undefined) { const predicate = compileFilter(params.filter); opts.filter = (id) => predicate(id, this.#payloads.get(id)); diff --git a/src/ergonomic/types.ts b/src/ergonomic/types.ts index 0a5306f..606688c 100644 --- a/src/ergonomic/types.ts +++ b/src/ergonomic/types.ts @@ -33,6 +33,12 @@ export interface CollectionConfig { seed?: number; /** Enable TQ+ calibration (default false; data-dependent — see TurboQuantIndex). */ calibrate?: boolean; + /** + * Enable IVF coarse-quantized search (default off): `nlist` cells trained from the + * first upsert of ≥ nlist points and frozen; queries probe `nprobe` cells (default + * ⌈nlist/8⌉, overridable per search). See TurboQuantIndex for the full contract. + */ + ivf?: { nlist: number; nprobe?: number }; } /** Per-search parameters. */ @@ -43,6 +49,8 @@ export interface SearchParams { filter?: Filter; /** Include each hit's payload (default true). */ withPayload?: boolean; + /** Override the IVF probe breadth for this query (ignored when IVF is not active). 
*/ + nprobe?: number; } // ── Filter DSL (qdrant-inspired) ─────────────────────────────────────────────── diff --git a/src/index.ts b/src/index.ts index 87e139f..60c19df 100644 --- a/src/index.ts +++ b/src/index.ts @@ -15,7 +15,11 @@ export const VERSION: string = __QUANTVEC_VERSION__; // ── Core positional index ───────────────────────────────────────────────────── export { TurboQuantIndex, IndexError } from './index/turboquant-index'; -export type { TurboQuantIndexOptions, IndexSearchOptions } from './index/turboquant-index'; +export type { + TurboQuantIndexOptions, + IndexSearchOptions, + IvfOptions, +} from './index/turboquant-index'; // ── Stable id-keyed index ───────────────────────────────────────────────────── export { IdMapIndex, IdMapError } from './index/id-map-index'; diff --git a/src/index/coarse.test.ts b/src/index/coarse.test.ts new file mode 100644 index 0000000..1acbff5 --- /dev/null +++ b/src/index/coarse.test.ts @@ -0,0 +1,187 @@ +import { describe, expect, it } from 'vitest'; +import { CoarseQuantizer, defaultNprobe, IVF_TRAIN_SAMPLE_PER_LIST } from './coarse'; +import { createRng } from '../core/rng'; + +const DIM = 8; + +function gaussianVecs(n: number, seed: number, dim = DIM): Float32Array[] { + const rng = createRng(seed); + return Array.from({ length: n }, () => { + const v = new Float32Array(dim); + for (let i = 0; i < dim; i++) v[i] = rng.nextGaussian(); + return v; + }); +} + +/** Assert the slot-membership invariants over every live slot. */ +function assertInvariants(cq: CoarseQuantizer, liveSlots: number): void { + const snapshot = cq.listForSlotSnapshot(liveSlots); + // probe with nprobe = nlist returns every live slot exactly once. 
+ const probeQuery = new Float32Array(DIM).fill(1); + const all = cq.probe(probeQuery, cq.nlist); + expect(all.length).toBe(liveSlots); + expect(new Set(Array.from(all)).size).toBe(liveSlots); + for (const s of all) { + expect(s).toBeGreaterThanOrEqual(0); + expect(s).toBeLessThan(liveSlots); + expect(snapshot[s]).toBeGreaterThanOrEqual(0); + expect(snapshot[s]).toBeLessThan(cq.nlist); + } +} + +describe('defaultNprobe', () => { + it('is 1/8 of nlist, at least 1', () => { + expect(defaultNprobe(1)).toBe(1); + expect(defaultNprobe(8)).toBe(1); + expect(defaultNprobe(64)).toBe(8); + expect(defaultNprobe(100)).toBe(13); + }); +}); + +describe('CoarseQuantizer — training', () => { + it('is deterministic for the same vectors and seed', () => { + const vecs = gaussianVecs(80, 1); + const a = CoarseQuantizer.train(vecs, 4, 1, 'cosine', DIM, 42); + const b = CoarseQuantizer.train(vecs, 4, 1, 'cosine', DIM, 42); + expect(Array.from(a.centroids)).toEqual(Array.from(b.centroids)); + }); + + it('spherical metrics produce unit-norm centroids; euclidean does not renormalize', () => { + const vecs = gaussianVecs(60, 2).map((v) => { + for (let i = 0; i < DIM; i++) v[i] = v[i]! * 3 + 5; // off-origin, non-unit + return v; + }); + const sph = CoarseQuantizer.train(vecs, 3, 1, 'dot', DIM, 7); + for (let c = 0; c < 3; c++) { + let n = 0; + for (let i = 0; i < DIM; i++) n += sph.centroids[c * DIM + i]! ** 2; + expect(Math.sqrt(n)).toBeCloseTo(1, 5); + } + const l2 = CoarseQuantizer.train(vecs, 3, 1, 'euclidean', DIM, 7); + let off = 0; + for (let c = 0; c < 3; c++) { + let n = 0; + for (let i = 0; i < DIM; i++) n += l2.centroids[c * DIM + i]! 
** 2; + if (Math.abs(Math.sqrt(n) - 1) > 0.1) off++; + } + expect(off).toBeGreaterThan(0); // means sit near the data, far from unit norm + }); + + it('caps the training sample (IVF_TRAIN_SAMPLE_PER_LIST) without error on big batches', () => { + const nlist = 2; + const vecs = gaussianVecs(IVF_TRAIN_SAMPLE_PER_LIST * nlist + 50, 3); + const cq = CoarseQuantizer.train(vecs, nlist, 1, 'cosine', DIM, 5); + expect(cq.centroids.length).toBe(nlist * DIM); + }); + + it('skips zero rows in training', () => { + const vecs = gaussianVecs(40, 4); + vecs[3] = new Float32Array(DIM); // zero row + const cq = CoarseQuantizer.train(vecs, 4, 1, 'cosine', DIM, 9); + for (const x of cq.centroids) expect(Number.isFinite(x)).toBe(true); + }); +}); + +describe('CoarseQuantizer — assignment & probing', () => { + it('assign matches brute-force nearest centroid for both affinities', () => { + for (const metric of ['cosine', 'euclidean'] as const) { + const vecs = gaussianVecs(64, 5); + const cq = CoarseQuantizer.train(vecs, 8, 2, metric, DIM, 11); + const rng = createRng(6); + for (let t = 0; t < 20; t++) { + const v = new Float32Array(DIM); + for (let i = 0; i < DIM; i++) v[i] = rng.nextGaussian(); + let best = 0; + let bestKey = -Infinity; + for (let c = 0; c < 8; c++) { + let dot = 0; + let d2 = 0; + for (let i = 0; i < DIM; i++) { + dot += v[i]! * cq.centroids[c * DIM + i]!; + d2 += (v[i]! - cq.centroids[c * DIM + i]!) ** 2; + } + const key = metric === 'euclidean' ? 
-d2 : dot; + if (key > bestKey) { + bestKey = key; + best = c; + } + } + expect(cq.assign(v)).toBe(best); + } + } + }); + + it('probe(nprobe = nlist) returns all live slots; smaller nprobe returns a subset', () => { + const vecs = gaussianVecs(100, 7); + const cq = CoarseQuantizer.train(vecs, 8, 2, 'cosine', DIM, 13); + for (let s = 0; s < vecs.length; s++) cq.addSlot(s, vecs[s]!); + const all = cq.probe(vecs[0]!, 8); + expect(all.length).toBe(100); + const some = cq.probe(vecs[0]!, 2); + expect(some.length).toBeLessThan(100); + const allSet = new Set(Array.from(all)); + for (const s of some) expect(allSet.has(s)).toBe(true); + }); +}); + +describe('CoarseQuantizer — posting-list bookkeeping', () => { + function build(n: number, seed: number): { cq: CoarseQuantizer; vecs: Float32Array[] } { + const vecs = gaussianVecs(n, seed); + const cq = CoarseQuantizer.train(vecs, 4, 1, 'cosine', DIM, 17); + for (let s = 0; s < n; s++) cq.addSlot(s, vecs[s]!); + return { cq, vecs }; + } + + it('swapRemove of the last slot (i === last)', () => { + const { cq } = build(10, 8); + cq.swapRemove(9, 9); + assertInvariants(cq, 9); + }); + + it('swapRemove of an interior slot renumbers the moved last slot', () => { + const { cq } = build(10, 9); + cq.swapRemove(3, 9); + assertInvariants(cq, 9); + }); + + it('invariant fuzz: random interleaved adds and removes', () => { + const rng = createRng(99); + const vecs = gaussianVecs(400, 10); + const cq = CoarseQuantizer.train(vecs.slice(0, 50), 4, 1, 'cosine', DIM, 19); + let n = 0; + let next = 0; + for (let op = 0; op < 300; op++) { + if (n === 0 || (rng.nextFloat() < 0.6 && next < vecs.length)) { + cq.addSlot(n, vecs[next]!); + next++; + n++; + } else { + const i = Math.min(n - 1, Math.floor(rng.nextFloat() * n)); + cq.swapRemove(i, n - 1); + n--; + } + } + assertInvariants(cq, n); + }); + + it('clear empties postings but keeps centroids; adds work after clear', () => { + const { cq, vecs } = build(20, 11); + const before = 
Array.from(cq.centroids); + cq.clear(); + expect(cq.probe(vecs[0]!, 4).length).toBe(0); + expect(Array.from(cq.centroids)).toEqual(before); + cq.addSlot(0, vecs[5]!); + assertInvariants(cq, 1); + }); + + it('fromState round-trips the snapshot', () => { + const { cq } = build(30, 12); + const snapshot = cq.listForSlotSnapshot(30); + const restored = CoarseQuantizer.fromState(cq.centroids, snapshot, 4, 1, 'cosine', DIM); + expect(Array.from(restored.listForSlotSnapshot(30))).toEqual(Array.from(snapshot)); + assertInvariants(restored, 30); + // Mutations on the restored quantizer keep the invariants. + restored.swapRemove(2, 29); + assertInvariants(restored, 29); + }); +}); diff --git a/src/index/coarse.ts b/src/index/coarse.ts new file mode 100644 index 0000000..03b0dd9 --- /dev/null +++ b/src/index/coarse.ts @@ -0,0 +1,265 @@ +// IVF coarse quantizer — the cell structure behind TurboQuantIndex's `ivf` mode. +// +// Why this exists: a flat scan is O(n) per query. The inverted-file (IVF) idea +// partitions the corpus into `nlist` cells around k-means centroids; a query +// ranks the centroids, probes only the `nprobe` nearest cells, and scans just +// those cells' members — sublinear work at large n. This class owns the cell +// state: the trained centroids (frozen, like the rotation and TQ+ calibration) +// and the posting lists that map cells → live slots, kept in lockstep with the +// index's swap-remove storage. +// +// Affinity is metric-consistent with the index: cosine/dot rank cells by +// argmax ⟨v, c⟩ over unit centroids (spherical k-means); euclidean ranks by +// argmin ‖v − c‖². The quantized scan over the probed slots is ./../core/search +// `searchSlots` — the same exact kernel as the flat scan, visited in posting-list +// order. At nprobe = nlist the index routes to the flat scan itself (canonical +// slot order), so that oracle case is EXACTLY the flat scan's result even when +// duplicate vectors tie at the k boundary. 
+ +import { kmeans, nearestCentroid } from '../core/kmeans'; +import type { Distance } from '../core/metrics'; +import { createRng } from '../core/rng'; +import { TopK } from '../core/topk'; + +/** Training sample cap: at most this many vectors per cell are fed to k-means. */ +export const IVF_TRAIN_SAMPLE_PER_LIST = 64; + +/** Default probe breadth: 1/8 of the cells, at least one. */ +export function defaultNprobe(nlist: number): number { + return Math.max(1, Math.ceil(nlist / 8)); +} + +/** + * The trained cell structure: centroids plus posting-list bookkeeping. + * + * Slot membership invariants (maintained by addSlot/swapRemove/clear, fuzzed in + * coarse.test.ts): for every live slot s, + * postings[listForSlot[s]][posForSlot[s]] === s + * and Σ_l |postings[l]| equals the number of live slots. + */ +export class CoarseQuantizer { + readonly nlist: number; + readonly dim: number; + readonly metric: Distance; + /** Resolved default probe breadth (validated by the index). */ + readonly defaultNprobe: number; + /** Row-major nlist·dim centroids (unit rows for cosine/dot). */ + readonly centroids: Float32Array; + + /** Cell members: slot ids per list (push/pop O(1)). */ + #postings: number[][]; + /** Slot → owning list (grown by doubling alongside the index's capacity). */ + #listForSlot: Int32Array; + /** Slot → position inside its posting list (for O(1) removal). */ + #posForSlot: Int32Array; + + private constructor( + centroids: Float32Array, + nlist: number, + nprobe: number, + metric: Distance, + dim: number, + ) { + this.nlist = nlist; + this.dim = dim; + this.metric = metric; + this.defaultNprobe = nprobe; + this.centroids = centroids; + this.#postings = Array.from({ length: nlist }, () => []); + this.#listForSlot = new Int32Array(0); + this.#posForSlot = new Int32Array(0); + } + + /** Whether cells use the spherical (dot) affinity. */ + get #spherical(): boolean { + return this.metric !== 'euclidean'; + } + + /** + * Train centroids from the first batch. 
The caller (TurboQuantIndex) + * guarantees `vecs.length >= nlist` and per-vector length === dim. Training + * samples at most {@link IVF_TRAIN_SAMPLE_PER_LIST}·nlist rows (partial + * Fisher–Yates over a domain-separated RNG derived from the index seed), so + * the k-means cost is bounded regardless of the first batch's size. Zero-norm + * rows are skipped the same way the calibration fitter skips them — encode + * rejects them moments later anyway. + */ + static train( + vecs: readonly Float32Array[], + nlist: number, + nprobe: number, + metric: Distance, + dim: number, + seed: number, + ): CoarseQuantizer { + // Domain-separate the k-means RNG stream from the rotation's stream derived + // from the same seed: XOR in "IVF1" (ASCII). createRng masks bigint seeds to + // 64 bits itself, and the separator fits in the mask, so XOR-then-mask here + // equals mask-then-XOR — no explicit masking needed. + const rng = createRng(BigInt(Math.trunc(seed)) ^ 0x49564631n); + const m = vecs.length; + const sampleSize = Math.min(m, IVF_TRAIN_SAMPLE_PER_LIST * nlist); + + // Partial Fisher–Yates: pick `sampleSize` distinct row indices deterministically. + const order = new Int32Array(m); + for (let i = 0; i < m; i++) order[i] = i; + for (let i = 0; i < sampleSize; i++) { + const j = i + Math.min(m - 1 - i, Math.floor(rng.nextFloat() * (m - i))); + const t = order[i]!; + order[i] = order[j]!; + order[j] = t; + } + + const spherical = metric !== 'euclidean'; + const data = new Float32Array(sampleSize * dim); + let rows = 0; + for (let s = 0; s < sampleSize; s++) { + const v = vecs[order[s]!]!; + let normSq = 0; + for (let i = 0; i < dim; i++) normSq += v[i]! * v[i]!; + if (normSq === 0) continue; // zero rows can't be normalized; encode will reject them + const base = rows * dim; + if (spherical) { + const inv = 1 / Math.sqrt(normSq); + for (let i = 0; i < dim; i++) data[base + i] = v[i]! 
* inv; + } else { + data.set(v, base); + } + rows++; + } + + const { centroids } = kmeans(data.subarray(0, rows * dim), rows, { + k: nlist, + dim, + rng, + spherical, + }); + return new CoarseQuantizer(centroids, nlist, nprobe, metric, dim); + } + + /** Rebuild from deserialized state; postings are reconstructed from `listForSlot`. */ + static fromState( + centroids: Float32Array, + listForSlot: Int32Array, + nlist: number, + nprobe: number, + metric: Distance, + dim: number, + ): CoarseQuantizer { + const cq = new CoarseQuantizer(centroids, nlist, nprobe, metric, dim); + for (let slot = 0; slot < listForSlot.length; slot++) { + cq.#growTo(slot + 1); + const list = listForSlot[slot]!; + cq.#listForSlot[slot] = list; + cq.#posForSlot[slot] = cq.#postings[list]!.length; + cq.#postings[list]!.push(slot); + } + return cq; + } + + /** Nearest cell for `vec` under the index's metric (raw, unrotated vector). */ + assign(vec: Float32Array): number { + return nearestCentroid(vec, 0, this.centroids, this.nlist, this.dim, this.#spherical); + } + + /** Assign `vec` and record `slot` (the index's next row) in that cell's list. */ + addSlot(slot: number, vec: Float32Array): void { + const list = this.assign(vec); + this.#growTo(slot + 1); + this.#listForSlot[slot] = list; + this.#posForSlot[slot] = this.#postings[list]!.length; + this.#postings[list]!.push(slot); + } + + /** + * Mirror the index's swap-remove of slot `i` (the row at slot `last = n−1` + * moved into the gap). Two memberships are patched, in this order: + * + * A) drop slot i from its own list by swap-pop: pop the list's tail; if the + * tail wasn't i itself, the tail fills i's hole (and its posForSlot moves). + * B) when i !== last, renumber `last` → `i` in last's list — reading last's + * position AFTER step A, because step A may have just moved it. + */ + swapRemove(i: number, last: number): void { + // Step A — remove slot i from its posting list. 
+ const li = this.#listForSlot[i]!; + const p = this.#posForSlot[i]!; + const listI = this.#postings[li]!; + const tail = listI.pop()!; + if (tail !== i) { + listI[p] = tail; + this.#posForSlot[tail] = p; + } + + // Step B — slot `last` now lives at slot `i`. + if (i !== last) { + const lj = this.#listForSlot[last]!; + const pj = this.#posForSlot[last]!; + this.#postings[lj]![pj] = i; + this.#listForSlot[i] = lj; + this.#posForSlot[i] = pj; + } + } + + /** Empty every posting list. The trained centroids survive — clearing data + * does not unfreeze the training decision (same contract as calibration). */ + clear(): void { + for (const list of this.#postings) list.length = 0; + } + + /** Snapshot of slot → list for the first `n` slots (the serialized form). */ + listForSlotSnapshot(n: number): Int32Array { + return this.#listForSlot.slice(0, n); + } + + /** + * Rank all cells against `query` (metric-consistent, higher-is-better) and + * concatenate the top-`nprobe` cells' slots into one array for `searchSlots`. + */ + probe(query: Float32Array, nprobe: number): Int32Array { + const { nlist, dim, centroids } = this; + const spherical = this.#spherical; + const top = new TopK(nprobe); + for (let c = 0; c < nlist; c++) { + let key: number; + if (spherical) { + let dot = 0; + const base = c * dim; + for (let i = 0; i < dim; i++) dot += query[i]! * centroids[base + i]!; + key = dot; + } else { + let d2 = 0; + const base = c * dim; + for (let i = 0; i < dim; i++) { + const d = query[i]! - centroids[base + i]!; + d2 += d * d; + } + key = -d2; + } + top.add(key, c); + } + + const { indices: lists } = top.result(); + let total = 0; + for (const l of lists) total += this.#postings[l]!.length; + const slots = new Int32Array(total); + let off = 0; + for (const l of lists) { + const list = this.#postings[l]!; + for (let t = 0; t < list.length; t++) slots[off++] = list[t]!; + } + return slots; + } + + /** Grow the slot-indexed arrays to hold at least `needed` slots (doubling). 
*/ + #growTo(needed: number): void { + if (needed <= this.#listForSlot.length) return; + const cap = Math.max(8, needed, this.#listForSlot.length * 2); + const nl = new Int32Array(cap); + nl.set(this.#listForSlot); + this.#listForSlot = nl; + const np = new Int32Array(cap); + np.set(this.#posForSlot); + this.#posForSlot = np; + } +} diff --git a/src/index/id-map-index.test.ts b/src/index/id-map-index.test.ts index 927e7d0..0fb8eae 100644 --- a/src/index/id-map-index.test.ts +++ b/src/index/id-map-index.test.ts @@ -288,3 +288,56 @@ describe('IdMapIndex — TQ+ calibration', () => { expect(restored.has(0)).toBe(true); }); }); + +describe('IdMapIndex — IVF passthrough', () => { + const VDIM = 16; + + function vecsAround(center: number, n: number, seed: number): Float32Array[] { + const rng = createRng(seed); + return Array.from({ length: n }, () => { + const v = new Float32Array(VDIM); + for (let i = 0; i < VDIM; i++) v[i] = center + rng.nextGaussian(); + return v; + }); + } + + it('trains from the first addWithIds batch and searches by id', () => { + const data = [...vecsAround(10, 20, 1), ...vecsAround(-10, 20, 2)]; + const ids = data.map((_, i) => 1000 + i); + const idx = new IdMapIndex({ dim: VDIM, ivf: { nlist: 2 } }); + idx.addWithIds(ids, data); + expect(idx.ivfActive).toBe(true); + const res = idx.search(data[5]!, 3, { nprobe: 2 }); + expect(res.ids).toContain(1005); + }); + + it('remove keeps parity with a flat twin; round-trip preserves ivf + id mapping', () => { + const data = [...vecsAround(10, 15, 3), ...vecsAround(-10, 15, 4)]; + const ids = data.map((_, i) => `p${i}`); + const ivf = new IdMapIndex({ dim: VDIM, ivf: { nlist: 2 } }); + const flat = new IdMapIndex({ dim: VDIM, wasm: false }); + ivf.addWithIds(ids, data); + flat.addWithIds(ids, data); + for (const victim of ['p3', 'p17', 'p0']) { + ivf.remove(victim); + flat.remove(victim); + } + const a = flat.search(data[5]!, 5); + const b = ivf.search(data[5]!, 5, { nprobe: 2 }); + 
expect(b.ids).toEqual(a.ids); + + const restored = IdMapIndex.fromBytes(ivf.toBytes()); + expect(restored.ivfActive).toBe(true); + expect(restored.search(data[5]!, 5, { nprobe: 2 }).ids).toEqual(b.ids); + }); + + it('filter predicates compose with the probed-cell scan', () => { + const data = [...vecsAround(10, 20, 5), ...vecsAround(-10, 20, 6)]; + const ids = data.map((_, i) => i); + const idx = new IdMapIndex({ dim: VDIM, ivf: { nlist: 2 } }); + idx.addWithIds(ids, data); + const res = idx.search(data[0]!, 10, { nprobe: 2, filter: (id) => id % 2 === 0 }); + expect(res.ids.every((id) => id % 2 === 0)).toBe(true); + expect(res.ids.length).toBeGreaterThan(0); + }); +}); diff --git a/src/index/id-map-index.ts b/src/index/id-map-index.ts index 1ad43fd..dbd3c42 100644 --- a/src/index/id-map-index.ts +++ b/src/index/id-map-index.ts @@ -51,6 +51,11 @@ export interface IdMapSearchOptions { * over large indexes the qdrant-style filter DSL (ergonomic layer) will be cheaper. */ filter?: (id: Id) => boolean; + /** + * Override the IVF probe breadth for this query (integer in [1, nlist]). Ignored + * when IVF is not active. See {@link TurboQuantIndexOptions.ivf}. + */ + nprobe?: number; } /** Result of {@link IdMapIndex.search}: external ids best-first plus aligned metric values. */ @@ -132,6 +137,11 @@ export class IdMapIndex { return this.#index.calibrated; } + /** Whether an IVF coarse quantizer was trained and is in effect. */ + get ivfActive(): boolean { + return this.#index.ivfActive; + } + /** Whether `id` is currently present. 
*/ has(id: Id): boolean { return this.#slotForId.has(id); @@ -219,6 +229,7 @@ export class IdMapIndex { validateVectorBatch(vecArr); this.#index.fitCalibrationFromBatch(vecArr); + this.#index.trainIvfFromBatch(vecArr); for (let j = 0; j < m; j++) { this.#index.addOne(vecArr[j]!); @@ -244,6 +255,7 @@ export class IdMapIndex { } const innerOpts: IndexSearchOptions = {}; if (opts.metric !== undefined) innerOpts.metric = opts.metric; + if (opts.nprobe !== undefined) innerOpts.nprobe = opts.nprobe; if (opts.filter !== undefined) { const filter = opts.filter; const mask = new Uint8Array(this.#idForSlot.length); diff --git a/src/index/turboquant-index.test.ts b/src/index/turboquant-index.test.ts index 95ed016..d2adf5d 100644 --- a/src/index/turboquant-index.test.ts +++ b/src/index/turboquant-index.test.ts @@ -576,3 +576,312 @@ describe('TurboQuantIndex — FastScan path (v128 blocked-nibble + exact rescore expect(totalHits / (queries.length * k)).toBeGreaterThan(0.8); }); }); + +describe('TurboQuantIndex — IVF coarse quantizer', () => { + const IDIM = 32; + + /** Gaussian mixture: `clusters` well-separated centers, `per` points each. */ + function clusteredVecs(clusters: number, per: number, seed: number): Float32Array[] { + const rng = createRng(seed); + const centers = Array.from({ length: clusters }, () => { + const c = new Float32Array(IDIM); + for (let i = 0; i < IDIM; i++) c[i] = rng.nextGaussian() * 10; + return c; + }); + const out: Float32Array[] = []; + for (let b = 0; b < clusters; b++) { + for (let j = 0; j < per; j++) { + const v = new Float32Array(IDIM); + for (let i = 0; i < IDIM; i++) v[i] = centers[b]![i]! 
+ rng.nextGaussian(); + out.push(v); + } + } + return out; + } + + it('validates nlist and nprobe at construction', () => { + for (const nlist of [1, 1.5, 0, 1 << 23]) { + let err: unknown; + try { + new TurboQuantIndex({ dim: IDIM, ivf: { nlist } }); + } catch (e) { + err = e; + } + expect(err).toBeInstanceOf(IndexError); + expect((err as IndexError).code).toBe('INVALID_NLIST'); + } + for (const nprobe of [0, 5, 2.5]) { + let err: unknown; + try { + new TurboQuantIndex({ dim: IDIM, ivf: { nlist: 4, nprobe } }); + } catch (e) { + err = e; + } + expect((err as IndexError).code).toBe('INVALID_NPROBE'); + } + }); + + it('trains on the first batch when it has ≥ nlist vectors, and freezes flat otherwise', () => { + const data = clusteredVecs(4, 10, 1); + const trained = new TurboQuantIndex({ dim: IDIM, ivf: { nlist: 4 } }); + trained.add(data); + expect(trained.ivfActive).toBe(true); + + const flat = new TurboQuantIndex({ dim: IDIM, ivf: { nlist: 8 } }); + flat.add(data.slice(0, 7)); // 7 < nlist = 8 → flat forever + expect(flat.ivfActive).toBe(false); + flat.add(data); // a later big add must NOT retrain (decision frozen) + expect(flat.ivfActive).toBe(false); + }); + + it('nprobe = nlist reproduces the flat scan exactly (indices and scores)', () => { + const data = clusteredVecs(8, 25, 2); + const queries = clusteredVecs(8, 1, 3); + const flat = new TurboQuantIndex({ dim: IDIM, wasm: false }); + const ivf = new TurboQuantIndex({ dim: IDIM, ivf: { nlist: 8 } }); + flat.add(data); + ivf.add(data); + expect(ivf.ivfActive).toBe(true); + for (const metric of ['dot', 'cosine', 'euclidean'] as const) { + for (const q of queries) { + const a = flat.search(q, 10, { metric }); + const b = ivf.search(q, 10, { metric, nprobe: 8 }); + expect(Array.from(b.indices)).toEqual(Array.from(a.indices)); + expect(Array.from(b.scores)).toEqual(Array.from(a.scores)); + } + } + }); + + it('achieves high recall at nprobe ≪ nlist on clustered data', () => { + const data = clusteredVecs(16, 30, 
4); + const queries = clusteredVecs(16, 1, 5); + const flat = new TurboQuantIndex({ dim: IDIM, wasm: false }); + const ivf = new TurboQuantIndex({ dim: IDIM, ivf: { nlist: 16, nprobe: 4 } }); + flat.add(data); + ivf.add(data); + const k = 10; + let hits = 0; + for (const q of queries) { + const exact = new Set(Array.from(flat.search(q, k).indices)); + for (const j of ivf.search(q, k).indices) if (exact.has(j)) hits++; + } + expect(hits / (queries.length * k)).toBeGreaterThan(0.8); + }); + + it('keeps remove parity with a flat twin (interleaved adds and removes)', () => { + const data = clusteredVecs(4, 30, 6); + const flat = new TurboQuantIndex({ dim: IDIM, wasm: false }); + const ivf = new TurboQuantIndex({ dim: IDIM, ivf: { nlist: 4 } }); + flat.add(data.slice(0, 80)); + ivf.add(data.slice(0, 80)); + const rng = createRng(7); + let n = 80; + let next = 80; + for (let op = 0; op < 60; op++) { + if (rng.nextFloat() < 0.5 && next < data.length) { + flat.addOne(data[next]!); + ivf.addOne(data[next]!); + next++; + n++; + } else if (n > 1) { + const i = Math.min(n - 1, Math.floor(rng.nextFloat() * n)); + flat.swapRemove(i); + ivf.swapRemove(i); + n--; + } + } + expect(ivf.size).toBe(flat.size); + const q = data[0]!; + const a = flat.search(q, 10); + const b = ivf.search(q, 10, { nprobe: 4 }); + expect(Array.from(b.indices)).toEqual(Array.from(a.indices)); + expect(Array.from(b.scores)).toEqual(Array.from(a.scores)); + }); + + it('round-trips through toBytes/fromBytes with identical search results', () => { + const data = clusteredVecs(4, 20, 8); + const ivf = new TurboQuantIndex({ + dim: IDIM, + metric: 'euclidean', + ivf: { nlist: 4, nprobe: 2 }, + }); + ivf.add(data); + const restored = TurboQuantIndex.fromBytes(ivf.toBytes()); + expect(restored.ivfActive).toBe(true); + expect(restored.size).toBe(ivf.size); + for (const q of data.slice(0, 5)) { + const a = ivf.search(q, 5); + const b = restored.search(q, 5); + 
expect(Array.from(b.indices)).toEqual(Array.from(a.indices)); + expect(Array.from(b.scores)).toEqual(Array.from(a.scores)); + } + // The restored index accepts further adds and removes without retraining. + restored.addOne(data[0]!); + restored.swapRemove(0); + expect(restored.ivfActive).toBe(true); + }); + + it('honors masks within the probed cells, and clear() keeps the trained quantizer', () => { + const data = clusteredVecs(4, 20, 9); + const ivf = new TurboQuantIndex({ dim: IDIM, ivf: { nlist: 4 } }); + ivf.add(data); + const mask = new Uint8Array(80).fill(1); + mask[3] = 0; + const res = ivf.search(data[3]!, 1, { mask, nprobe: 4 }); + expect(res.indices[0]).not.toBe(3); + + ivf.clear(); + expect(ivf.size).toBe(0); + expect(ivf.ivfActive).toBe(true); // centroids survive; decision stays frozen + ivf.add(data.slice(0, 10)); // post-clear adds assign to the existing cells + expect(ivf.search(data[0]!, 1, { nprobe: 4 }).indices.length).toBe(1); + }); + + it('ignores nprobe when IVF is not active; validates it per query when active', () => { + const flat = new TurboQuantIndex({ dim: IDIM }); + flat.add(clusteredVecs(2, 5, 10)); + expect(() => + flat.search(flat.size > 0 ? clusteredVecs(1, 1, 11)[0]! 
: new Float32Array(IDIM), 2, { + nprobe: 999, + }), + ).not.toThrow(); + + const ivf = new TurboQuantIndex({ dim: IDIM, ivf: { nlist: 4 } }); + ivf.add(clusteredVecs(4, 10, 12)); + let err: unknown; + try { + ivf.search(clusteredVecs(1, 1, 13)[0]!, 2, { nprobe: 5 }); + } catch (e) { + err = e; + } + expect((err as IndexError).code).toBe('INVALID_NPROBE'); + }); +}); + +describe('TurboQuantIndex — IVF input hardening (review follow-ups)', () => { + const IDIM = 32; + + function clusteredVecs(clusters: number, per: number, seed: number): Float32Array[] { + const rng = createRng(seed); + const centers = Array.from({ length: clusters }, () => { + const c = new Float32Array(IDIM); + for (let i = 0; i < IDIM; i++) c[i] = rng.nextGaussian() * 10; + return c; + }); + const out: Float32Array[] = []; + for (let b = 0; b < clusters; b++) { + for (let j = 0; j < per; j++) { + const v = new Float32Array(IDIM); + for (let i = 0; i < IDIM; i++) v[i] = centers[b]![i]! + rng.nextGaussian(); + out.push(v); + } + } + return out; + } + + it('rejects malformed queries with the same typed errors as the flat path, before probing', () => { + const data = clusteredVecs(4, 10, 14); + const ivf = new TurboQuantIndex({ dim: IDIM, ivf: { nlist: 4 } }); + const flat = new TurboQuantIndex({ dim: IDIM, wasm: false }); + ivf.add(data); + flat.add(data); + const bads: [Float32Array, string][] = [ + [new Float32Array(IDIM - 1), 'INVALID_LENGTH'], // wrong length + [new Float32Array(IDIM).fill(Number.NaN), 'INVALID_LENGTH'], // non-finite + [new Float32Array(IDIM), 'ZERO_QUERY'], // zero query + ]; + for (const [bad, code] of bads) { + for (const idx of [ivf, flat]) { + let err: unknown; + try { + idx.search(bad, 2); + } catch (e) { + err = e; + } + expect(err).toBeInstanceOf(SearchError); + expect((err as SearchError).code).toBe(code); + } + } + }); + + it('validates the first training batch atomically: a bad row leaves the index unchanged', () => { + const good = clusteredVecs(4, 10, 15); + const 
nanRow = new Float32Array(IDIM).fill(1); + nanRow[3] = Number.NaN; + for (const poison of [nanRow, new Float32Array(IDIM) /* zero row */]) { + const ivf = new TurboQuantIndex({ dim: IDIM, ivf: { nlist: 4 } }); + let err: unknown; + try { + ivf.add([...good.slice(0, 10), poison, ...good.slice(10)]); + } catch (e) { + err = e; + } + expect(err).toBeInstanceOf(EncodeError); + expect(ivf.size).toBe(0); // nothing appended + expect(ivf.ivfActive).toBe(false); // nothing trained on the poisoned batch + ivf.add(good); // the decision is NOT frozen by the failed batch + expect(ivf.ivfActive).toBe(true); + expect(ivf.size).toBe(good.length); + } + }); + + it('validates the first calibration batch the same way', () => { + const rng = createRng(16); + const good = Array.from({ length: CALIBRATION_MIN_SAMPLES }, () => { + const v = new Float32Array(IDIM); + for (let i = 0; i < IDIM; i++) v[i] = rng.nextGaussian(); + return v; + }); + const idx = new TurboQuantIndex({ dim: IDIM, calibrate: true, wasm: false }); + const poisoned = [...good]; + poisoned[7] = new Float32Array(IDIM); // zero row + let err: unknown; + try { + idx.add(poisoned); + } catch (e) { + err = e; + } + expect(err).toBeInstanceOf(EncodeError); + expect(idx.size).toBe(0); + expect(idx.calibrated).toBe(false); + idx.add(good); // decision not frozen; a clean batch still calibrates + expect(idx.calibrated).toBe(true); + }); +}); + +describe('TurboQuantIndex — IVF oracle with tied scores', () => { + const IDIM = 32; + + it('nprobe = nlist matches the flat scan exactly even with duplicate vectors (boundary ties)', () => { + const rng = createRng(77); + // 4 tight clusters; every vector duplicated 3× → exact rankKey ties everywhere, + // including at the k boundary. The top-k heap keeps tied candidates by visit + // order, so the oracle is unconditional only because nprobe = nlist routes to + // the canonical flat scan — this guards that routing against regression. 
+ const base: Float32Array[] = []; + for (let b = 0; b < 4; b++) { + const center = new Float32Array(IDIM); + for (let i = 0; i < IDIM; i++) center[i] = rng.nextGaussian() * 10; + for (let j = 0; j < 10; j++) { + const v = new Float32Array(IDIM); + for (let i = 0; i < IDIM; i++) v[i] = center[i]! + rng.nextGaussian(); + base.push(v); + } + } + const data: Float32Array[] = []; + for (const v of base) data.push(v, Float32Array.from(v), Float32Array.from(v)); + + const flat = new TurboQuantIndex({ dim: IDIM, wasm: false }); + const ivf = new TurboQuantIndex({ dim: IDIM, ivf: { nlist: 4 } }); + flat.add(data); + ivf.add(data); + for (const q of base.slice(0, 8)) { + for (const k of [2, 4, 7]) { + const a = flat.search(q, k); + const b = ivf.search(q, k, { nprobe: 4 }); + expect(Array.from(b.indices)).toEqual(Array.from(a.indices)); + expect(Array.from(b.scores)).toEqual(Array.from(a.scores)); + } + } + }); +}); diff --git a/src/index/turboquant-index.ts b/src/index/turboquant-index.ts index 1504aa1..5e950d0 100644 --- a/src/index/turboquant-index.ts +++ b/src/index/turboquant-index.ts @@ -22,13 +22,14 @@ import { fitCalibration } from '../core/calibrate'; import type { Calibration } from '../core/calibrate'; import { getCodebook } from '../core/codebook'; import type { Bits, Codebook } from '../core/codebook'; -import { createEncodeScratch, encodeVector } from '../core/encode'; +import { createEncodeScratch, encodeVector, validateVectorBatch } from '../core/encode'; import type { EncodeOptions, EncodeScratch } from '../core/encode'; import { scoreMetric } from '../core/metrics'; import type { Distance, QueryNorms } from '../core/metrics'; import { createRotation } from '../core/rotation'; import type { Rotation } from '../core/rotation'; -import { buildQueryLut, searchFlat, SearchError } from '../core/search'; +import { buildQueryLut, searchFlat, searchSlots, validateQuery, SearchError } from '../core/search'; +import { CoarseQuantizer, defaultNprobe } from './coarse'; 
import type { EncodedDb, SearchOptions, SearchResult } from '../core/search'; import { TopK } from '../core/topk'; import { WasmKernel } from '../wasm/kernel'; @@ -44,6 +45,8 @@ export class IndexError extends Error { | 'INVALID_LENGTH' | 'INVALID_VECTOR' | 'INVALID_INDEX' + | 'INVALID_NLIST' + | 'INVALID_NPROBE' | 'EMPTY' | 'WRONG_KIND'; constructor(code: IndexError['code'], message: string) { @@ -53,6 +56,22 @@ export class IndexError extends Error { } } +/** IVF (inverted-file) coarse-quantizer options — see {@link TurboQuantIndexOptions.ivf}. */ +export interface IvfOptions { + /** + * Number of coarse cells (posting lists); integer in [2, 2^22]. Rule of thumb: + * ~√n cells for an n-vector corpus; training quality wants a first batch of + * at least ~32·nlist vectors (the hard minimum is nlist). + */ + nlist: number; + /** + * Default number of cells probed per query; integer in [1, nlist]. Higher = + * better recall, slower. Defaults to max(1, ⌈nlist/8⌉). Overridable per query + * via {@link IndexSearchOptions.nprobe}. + */ + nprobe?: number; +} + /** Construction options for {@link TurboQuantIndex}. */ export interface TurboQuantIndexOptions { /** Vector dimension d; must be a positive multiple of 8 (rotation/codebook precondition). */ @@ -92,6 +111,16 @@ export interface TurboQuantIndexOptions { * the exact scan). */ fastscan?: boolean; + /** + * Enable IVF coarse-quantized search (default off). When the first non-empty add + * supplies at least `nlist` vectors, a k-means coarse quantizer is trained from that + * batch and frozen for the index's lifetime; queries then probe only the `nprobe` + * nearest cells instead of scanning all vectors — sublinear work on large corpora. + * A smaller first batch freezes the index flat forever (same contract as + * `calibrate`). While IVF is active the WASM/FastScan whole-database kernels are + * bypassed (the probed-cell scan is scalar; a cell-resident kernel is a future wave). 
+ */ + ivf?: IvfOptions; } /** Per-query options for {@link TurboQuantIndex.search}. */ @@ -100,6 +129,12 @@ export interface IndexSearchOptions { metric?: Distance; /** Optional allowlist (length = size): vector j is scanned only if mask[j] is truthy. */ mask?: Uint8Array | boolean[]; + /** + * Override the IVF probe breadth for this query (integer in [1, nlist]). Ignored + * when IVF is not active — the same call site legitimately runs both before and + * after the training decision freezes. + */ + nprobe?: number; } /** Initial backing-array capacity before the first growth (kept small; doubles on demand). */ @@ -126,6 +161,33 @@ function validateSeed(seed: number): void { } } +/** Upper bound on nlist: keeps the u32 serialization field honest and the centroid + * allocation sane (4M cells is far beyond the library's corpus-scale target). */ +const MAX_NLIST = 1 << 22; + +/** Validate {@link IvfOptions}, returning the resolved default nprobe. */ +function validateIvf(ivf: IvfOptions): number { + const { nlist } = ivf; + if (!Number.isInteger(nlist) || nlist < 2 || nlist > MAX_NLIST) { + throw new IndexError( + 'INVALID_NLIST', + `ivf.nlist must be an integer in [2, 2^22], got ${nlist}`, + ); + } + const nprobe = ivf.nprobe ?? defaultNprobe(nlist); + validateNprobe(nprobe, nlist); + return nprobe; +} + +function validateNprobe(nprobe: number, nlist: number): void { + if (!Number.isInteger(nprobe) || nprobe < 1 || nprobe > nlist) { + throw new IndexError( + 'INVALID_NPROBE', + `nprobe must be an integer in [1, nlist=${nlist}], got ${nprobe}`, + ); + } +} + /** * Length of an array-like vector argument, throwing a typed error if it is not * array-like (e.g. `null`/`undefined` slipping in from untyped JS callers) — keeps a @@ -187,6 +249,12 @@ export class TurboQuantIndex { #wasm: WasmKernel | null | undefined; /** True when the resident WASM codes are stale (mutation since last upload). 
*/ #wasmDirty: boolean; + /** Requested IVF config with the default nprobe resolved, or undefined when off. */ + readonly #ivfOpts: { nlist: number; nprobe: number } | undefined; + /** Trained IVF coarse quantizer, or undefined while flat. */ + #coarse: CoarseQuantizer | undefined; + /** Once true, the IVF training decision is locked for the index's lifetime. */ + #ivfFrozen: boolean; constructor(options: TurboQuantIndexOptions) { const { @@ -197,10 +265,14 @@ export class TurboQuantIndex { calibrate = false, wasm = true, fastscan = false, + ivf, } = options; validateDim(dim); validateBits(bits); validateSeed(seed); + this.#ivfOpts = ivf === undefined ? undefined : { nlist: ivf.nlist, nprobe: validateIvf(ivf) }; + this.#coarse = undefined; + this.#ivfFrozen = false; this.#dim = dim; this.#bits = bits; @@ -253,6 +325,11 @@ export class TurboQuantIndex { return this.#calibration !== undefined; } + /** Whether an IVF coarse quantizer was trained and is in effect. */ + get ivfActive(): boolean { + return this.#coarse !== undefined; + } + /** * Whether the v128 FastScan candidate-pool path is active for this index. This is a * perf-hint constructor option (and requires `bits === 4`) — it is not part of the @@ -301,6 +378,8 @@ export class TurboQuantIndex { this.#norms[slot] = encoded.norm; this.#n = slot + 1; this.#wasmDirty = true; + // The raw (unrotated) vector is exactly what cell assignment ranks against. + if (this.#coarse !== undefined) this.#coarse.addSlot(slot, vec); } /** Normalize an `add`/`addWithIds` batch argument to validated per-vector views. */ @@ -357,6 +436,28 @@ export class TurboQuantIndex { this.#calibrationFrozen = true; } + /** + * @internal Train and freeze the IVF coarse quantizer from the first eligible batch + * (only when the index is still empty, `ivf` is enabled, and the batch has at least + * `nlist` vectors — the hard k-means floor; ≥ ~32·nlist recommended for quality); + * otherwise lock the flat decision. 
An empty batch is a no-op so it does not + * prematurely freeze. Shared by {@link add} and {@link IdMapIndex}. + */ + trainIvfFromBatch(vecs: readonly Float32Array[]): void { + if (this.#ivfFrozen || this.#n !== 0 || vecs.length === 0) return; + if (this.#ivfOpts !== undefined && vecs.length >= this.#ivfOpts.nlist) { + this.#coarse = CoarseQuantizer.train( + vecs, + this.#ivfOpts.nlist, + this.#ivfOpts.nprobe, + this.#metric, + this.#dim, + this.#seed, + ); + } + this.#ivfFrozen = true; + } + /** * Add one or more vectors. Accepts a flat `Float32Array` of m·dim values (m * vectors laid out row-major), or an array of per-vector `Float32Array` / @@ -368,11 +469,24 @@ export class TurboQuantIndex { * @throws {EncodeError} (re-thrown) on a non-finite or zero vector, or * (`'DEGENERATE'`, calibrated indexes only) a vector so far outside the calibrated * distribution that it cannot be encoded faithfully. Note: a batch is appended in - * order, so an encode error mid-batch leaves the preceding vectors added. + * order, so an encode error mid-batch leaves the preceding vectors added — except + * the first batch of a `calibrate`/`ivf` index, which is validated atomically up + * front (a bad row would otherwise poison the frozen calibration/centroids before + * encode could reject it), so it leaves the index completely unchanged. */ add(vectors: Float32Array | number[][] | Float32Array[]): void { const vecs = this.#toVectorArray(vectors); + // A pending training decision must only ever see a valid batch: a non-finite or + // zero row would poison the calibration fit / k-means centroids that are about to + // be frozen. Validate atomically before fitting (IdMapIndex/Collection already do). 
+ const trainingPending = + this.#n === 0 && + vecs.length > 0 && + ((this.#calibrate && !this.#calibrationFrozen) || + (this.#ivfOpts !== undefined && !this.#ivfFrozen)); + if (trainingPending) validateVectorBatch(vecs); this.fitCalibrationFromBatch(vecs); // first eligible batch fits + freezes TQ+ + this.trainIvfFromBatch(vecs); // first eligible batch trains + freezes IVF this.#ensureCapacity(this.#n + vecs.length); for (let j = 0; j < vecs.length; j++) this.#appendOne(vecs[j]!); } @@ -395,10 +509,13 @@ export class TurboQuantIndex { this.#appendOne(vec instanceof Float32Array ? vec : Float32Array.from(vec)); } - /** Remove all vectors, resetting the live count to 0 (backing capacity is retained). */ + /** Remove all vectors, resetting the live count to 0 (backing capacity is retained). + * A trained IVF quantizer keeps its centroids — clearing data does not unfreeze + * the training decision (same contract as calibration). */ clear(): void { this.#n = 0; this.#wasmDirty = true; + this.#coarse?.clear(); } /** Build a read-only {@link EncodedDb} view over the live rows (no copy). */ @@ -443,6 +560,29 @@ export class TurboQuantIndex { const searchOpts: SearchOptions = opts.mask === undefined ? { metric } : { metric, mask: opts.mask }; + // ── IVF: probe the nprobe nearest cells, scan only their slots ────────── + // The whole-database WASM/FastScan kernels are bypassed while IVF is active + // (the probed-cell scan is scalar; a cell-resident kernel is a future wave). + // `opts.nprobe` is ignored when flat — the same call site legitimately runs + // both before and after the training decision freezes. + if (this.#coarse !== undefined) { + const nprobe = opts.nprobe ?? this.#coarse.defaultNprobe; + validateNprobe(nprobe, this.#coarse.nlist); + // Same typed errors as the flat path, checked BEFORE the centroid probe so a + // malformed query (wrong length / non-finite / zero) never reaches the probe + // arithmetic. 
searchSlots re-validates via the shared preamble — cheap (O(dim)). + validateQuery(query, this.#dim); + // nprobe = nlist must reproduce the flat scan EXACTLY (the IVF oracle). The + // probed scan visits slots in posting-list order, and the top-k heap drops + // boundary ties, so with duplicate vectors the kept set is order-dependent — + // scan in canonical slot order instead (also skips a pointless full probe). + if (nprobe === this.#coarse.nlist) { + return searchFlat(this.#db(), query, k, searchOpts); + } + const slots = this.#coarse.probe(query, nprobe); + return searchSlots(this.#db(), query, k, slots, searchOpts); + } + if (this.#wasmEnabled) { if (this.#wasm === undefined) this.#wasm = WasmKernel.create(); const kernel = this.#wasm; @@ -576,6 +716,7 @@ export class TurboQuantIndex { } this.#n = last; this.#wasmDirty = true; + this.#coarse?.swapRemove(i, last); } /** @@ -597,6 +738,14 @@ export class TurboQuantIndex { norms: this.#norms.subarray(0, n), }; if (this.#calibration !== undefined) payload.calibration = this.#calibration; + if (this.#coarse !== undefined) { + payload.ivf = { + nlist: this.#coarse.nlist, + nprobe: this.#coarse.defaultNprobe, + centroids: this.#coarse.centroids, + listForSlot: this.#coarse.listForSlotSnapshot(n), + }; + } return payload; } @@ -620,6 +769,18 @@ export class TurboQuantIndex { // Adopt the stored calibration and lock the decision (so later adds don't refit). if (payload.calibration !== undefined) idx.#calibration = payload.calibration; idx.#calibrationFrozen = true; + // Adopt the stored IVF state the same way (postings rebuilt from listForSlot). 
+ if (payload.ivf !== undefined) { + idx.#coarse = CoarseQuantizer.fromState( + payload.ivf.centroids, + payload.ivf.listForSlot, + payload.ivf.nlist, + payload.ivf.nprobe, + payload.metric, + payload.dim, + ); + } + idx.#ivfFrozen = true; return idx; } diff --git a/src/io/serialize.test.ts b/src/io/serialize.test.ts index 31ba4da..648582c 100644 --- a/src/io/serialize.test.ts +++ b/src/io/serialize.test.ts @@ -33,10 +33,10 @@ function fixedEnd(n: number): number { return 24 + (n * DIM * BITS) / 8 + 2 * n * 4; } -/** Start of the ids section: fixed region + the 1-byte calibration-presence flag - * (the test payloads carry no calibration, so the flag is a single 0 byte). */ +/** Start of the ids section: fixed region + the 1-byte calibration-presence flag + + * the 1-byte ivf-presence flag (the test payloads carry neither, so both are 0). */ function idsStart(n: number): number { - return fixedEnd(n) + 1; + return fixedEnd(n) + 2; } function expectDeserializeError(bytes: Uint8Array, code: DeserializeError['code']): void { @@ -127,10 +127,12 @@ describe('deserialize — untrusted input validation', () => { expectDeserializeError(b, 'BAD_MAGIC'); }); - it('rejects an unsupported version', () => { - const b = validPos(); - b[OFF.version] = 2; - expectDeserializeError(b, 'BAD_VERSION'); + it('rejects an unsupported version (v1 buffers included — D-010 bump-and-rewrite)', () => { + for (const v of [1, 3]) { + const b = validPos(); + b[OFF.version] = v; + expectDeserializeError(b, 'BAD_VERSION'); + } }); it('rejects an unknown kind', () => { @@ -309,3 +311,132 @@ describe('serialize — input typing', () => { expect(serializeIndex(payload)).toBeInstanceOf(Uint8Array); }); }); + +// ── IVF section (format v2) ───────────────────────────────────────────────────── + +/** A self-consistent IVF payload for an n-row index with `nlist` cells. 
*/ +function ivfPayload(n: number, nlist = 2): NonNullable<SerializableIndex['ivf']> { + const centroids = new Float32Array(nlist * DIM); + for (let i = 0; i < centroids.length; i++) centroids[i] = (i % 5) - 2; // exact in f32 + const listForSlot = new Int32Array(n); + for (let i = 0; i < n; i++) listForSlot[i] = i % nlist; + return { nlist, nprobe: 1, centroids, listForSlot }; +} + +/** Byte offset of the ivf flag (fixed region + 1 calibration flag byte). */ +function ivfFlagAt(n: number): number { + return fixedEnd(n) + 1; +} + +describe('serialize/deserialize — ivf section', () => { + it('round-trips ivf state for both kinds, with and without calibration', () => { + const ivf = ivfPayload(3, 2); + const pos: SerializableIndex = { kind: 'positional', ...positionalPayload(3), ivf }; + const parsedPos = deserializeIndex(serializeIndex(pos)); + expect(parsedPos.ivf).toBeDefined(); + expect(parsedPos.ivf!.nlist).toBe(2); + expect(parsedPos.ivf!.nprobe).toBe(1); + expect(Array.from(parsedPos.ivf!.centroids)).toEqual(Array.from(ivf.centroids)); + expect(Array.from(parsedPos.ivf!.listForSlot)).toEqual(Array.from(ivf.listForSlot)); + + const calibration = { + shift: new Float32Array(DIM).fill(0.25), + scale: new Float32Array(DIM).fill(1.5), + }; + const both: SerializableIndex = { + kind: 'idmap', + ...positionalPayload(3), + calibration, + ivf, + ids: [7, 8, 9], + }; + const parsedBoth = deserializeIndex(serializeIndex(both)); + expect(parsedBoth.calibration).toBeDefined(); + expect(parsedBoth.ivf).toBeDefined(); + expect((parsedBoth as { ids: IdType[] }).ids).toEqual([7, 8, 9]); + expect(Array.from(parsedBoth.ivf!.listForSlot)).toEqual([0, 1, 0]); + }); + + it('round-trips an empty (n = 0) index with ivf present — trained then cleared', () => { + const parsed = deserializeIndex( + serializeIndex({ kind: 'positional', ...positionalPayload(0), ivf: ivfPayload(0, 2) }), + ); + expect(parsed.ivf!.listForSlot.length).toBe(0); + expect(parsed.ivf!.centroids.length).toBe(2 * DIM); + }); + + it('omits 
the section cleanly when absent (flag 0)', () => { + const parsed = deserializeIndex( + serializeIndex({ kind: 'positional', ...positionalPayload(2) }), + ); + expect(parsed.ivf).toBeUndefined(); + }); + + function craftedIvf( + mutate: (b: Uint8Array, dv: DataView, flagAt: number) => Uint8Array | void, + ): Uint8Array { + const bytes = serializeIndex({ + kind: 'positional', + ...positionalPayload(2), + ivf: ivfPayload(2, 2), + }); + const dv = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength); + return mutate(bytes, dv, ivfFlagAt(2)) ?? bytes; + } + + it('rejects an invalid ivf flag byte', () => { + expectDeserializeError( + craftedIvf((b, _dv, at) => void (b[at] = 7)), + 'BAD_IVF', + ); + }); + + it('rejects truncation before nlist/nprobe', () => { + expectDeserializeError( + craftedIvf((b, _dv, at) => b.slice(0, at + 3)), + 'BAD_IVF', + ); + }); + + it.each([0, 1, (1 << 22) + 1])('rejects nlist = %s out of [2, 2^22]', (nlist) => { + expectDeserializeError( + craftedIvf((_b, dv, at) => void dv.setUint32(at + 1, nlist, true)), + 'BAD_IVF', + ); + }); + + it.each([0, 3])('rejects nprobe = %s outside [1, nlist = 2]', (nprobe) => { + expectDeserializeError( + craftedIvf((_b, dv, at) => void dv.setUint32(at + 5, nprobe, true)), + 'BAD_IVF', + ); + }); + + it('rejects a truncated centroids/listForSlot region', () => { + expectDeserializeError( + craftedIvf((b, _dv, at) => b.slice(0, at + 9 + 4)), + 'BAD_IVF', + ); + }); + + it('rejects a non-finite centroid coordinate', () => { + expectDeserializeError( + craftedIvf((_b, dv, at) => void dv.setFloat32(at + 9, Number.NaN, true)), + 'BAD_IVF', + ); + }); + + it('rejects a listForSlot entry >= nlist', () => { + expectDeserializeError( + craftedIvf((_b, dv, at) => void dv.setUint32(at + 9 + 2 * DIM * 4, 2, true)), + 'BAD_IVF', + ); + }); + + it('rejects trailing bytes after the ivf section (positional)', () => { + const ok = craftedIvf(() => undefined); + const extended = new Uint8Array(ok.length + 3); + 
extended.set(ok); + expectDeserializeError(extended, 'BAD_LENGTH'); + }); +}); diff --git a/src/io/serialize.ts b/src/io/serialize.ts index 991cd49..cf61452 100644 --- a/src/io/serialize.ts +++ b/src/io/serialize.ts @@ -11,14 +11,17 @@ // // Layout (all multi-byte fields little-endian; header = 24 bytes): // [0..4) magic "QVEC" (0x51 0x56 0x45 0x43) -// [4] version u8 (= 1) +// [4] version u8 (= 2) // [5] kind u8 (0 = positional, 1 = idmap) // [6] metric u8 (0 = dot, 1 = cosine, 2 = euclidean) // [7] bits u8 (2 | 3 | 4) // [8..12) dim u32 // [12..16) n u32 // [16..24) seed f64 -// [24..] codes (n·dim bytes) · scales (n·f32) · norms (n·f32) +// [24..] codes (⌈n·dim·bits/8⌉ bytes, bit-packed) · scales (n·f32) · norms (n·f32) +// [cali] flag u8 ∈ {0,1} ; 1 → shift (dim·f32) + scale (dim·f32) +// [ivf] flag u8 ∈ {0,1} ; 1 → nlist u32 + nprobe u32 +// + centroids (nlist·dim·f32) + listForSlot (n·u32) // [idmap] ids: n × { tag u8 ; 0→f64 | 1→(u32 len + utf8) | 2→(u32 len + utf8 of BigInt) } import type { Calibration } from '../core/calibrate'; @@ -42,7 +45,8 @@ export class DeserializeError extends Error { | 'BAD_SEED' | 'BAD_LENGTH' | 'BAD_ID' - | 'BAD_CALIBRATION'; + | 'BAD_CALIBRATION' + | 'BAD_IVF'; constructor(code: DeserializeError['code'], message: string) { super(message); this.name = 'DeserializeError'; @@ -70,6 +74,20 @@ export interface IndexPayload { norms: Float32Array; /** Optional frozen TQ+ calibration (shift/scale, length dim each); absent if un-calibrated. */ calibration?: Calibration; + /** Optional frozen IVF coarse-quantizer state; absent while the index is flat. */ + ivf?: IvfPayload; +} + +/** Serialized IVF coarse-quantizer state (postings are rebuilt from `listForSlot`). */ +export interface IvfPayload { + /** Number of coarse cells; integer in [2, 2^22]. */ + nlist: number; + /** Default probe breadth; integer in [1, nlist]. */ + nprobe: number; + /** Row-major nlist·dim cell centroids. 
*/ + centroids: Float32Array; + /** Owning cell per live slot, length n (each entry < nlist). */ + listForSlot: Int32Array; } /** Serialize input for the positional index. */ @@ -102,8 +120,10 @@ export interface DeserializedIdMap extends IndexPayload { export type DeserializedIndex = DeserializedPositional | DeserializedIdMap; const MAGIC = 0x51564543; // "QVEC" read big-endian as a u32 (matches the byte order written below) -const VERSION = 1; +const VERSION = 2; const HEADER_BYTES = 24; +/** Same bound the index enforces at construction (see ../index/turboquant-index). */ +const MAX_NLIST = 1 << 22; const FLOAT_BYTES = 4; // width of an f32 scale/norm field const U32_BYTES = 4; // width of the u32 length prefix on a string/bigint id @@ -144,7 +164,11 @@ export function serializeIndex(payload: SerializableIndex): Uint8Array { // Calibration section: 1 presence byte, then (if present) shift[dim] + scale[dim] f32. const calibration = payload.calibration; const caliBytes = 1 + (calibration ? 2 * dim * FLOAT_BYTES : 0); - const bodyBytes = packedCodes.length + (scales.length + norms.length) * FLOAT_BYTES + caliBytes; + // IVF section: 1 presence byte, then (if present) nlist + nprobe u32, centroids, listForSlot. + const ivf = payload.ivf; + const ivfBytes = 1 + (ivf ? 2 * U32_BYTES + ivf.nlist * dim * FLOAT_BYTES + n * U32_BYTES : 0); + const bodyBytes = + packedCodes.length + (scales.length + norms.length) * FLOAT_BYTES + caliBytes + ivfBytes; // For the id-keyed index, pre-encode the ids so we can size the buffer exactly. 
const enc = new TextEncoder(); @@ -194,6 +218,27 @@ export function serializeIndex(payload: SerializableIndex): Uint8Array { off += 1; } + // ── IVF (presence byte, then nlist/nprobe + centroids + listForSlot if present) ── + if (ivf) { + dv.setUint8(off, 1); + off += 1; + dv.setUint32(off, ivf.nlist, true); + off += U32_BYTES; + dv.setUint32(off, ivf.nprobe, true); + off += U32_BYTES; + for (let i = 0; i < ivf.centroids.length; i++) { + dv.setFloat32(off, ivf.centroids[i]!, true); + off += FLOAT_BYTES; + } + for (let i = 0; i < ivf.listForSlot.length; i++) { + dv.setUint32(off, ivf.listForSlot[i]!, true); + off += U32_BYTES; + } + } else { + dv.setUint8(off, 0); + off += 1; + } + // ── Ids (idmap only) ──────────────────────────────────────────────────────── if (encodedIds) { const ids = (payload as IdMapPayload).ids; @@ -337,8 +382,61 @@ export function deserializeIndex(bytes: Uint8Array): DeserializedIndex { throw new DeserializeError('BAD_LENGTH', `invalid calibration flag ${caliFlag}`); } + // ── IVF (presence byte, then nlist/nprobe + centroids + listForSlot if present) ── + if (off + 1 > bytes.length) { + throw new DeserializeError('BAD_LENGTH', 'truncated before ivf flag'); + } + const ivfFlag = dv.getUint8(off); + off += 1; + let ivf: IvfPayload | undefined; + if (ivfFlag === 1) { + if (off + 2 * U32_BYTES > bytes.length) { + throw new DeserializeError('BAD_IVF', 'truncated before ivf nlist/nprobe'); + } + const nlist = dv.getUint32(off, true); + off += U32_BYTES; + const nprobe = dv.getUint32(off, true); + off += U32_BYTES; + if (nlist < 2 || nlist > MAX_NLIST) { + throw new DeserializeError('BAD_IVF', `nlist must be in [2, 2^22], got ${nlist}`); + } + if (nprobe < 1 || nprobe > nlist) { + throw new DeserializeError('BAD_IVF', `nprobe must be in [1, nlist=${nlist}], got ${nprobe}`); + } + // Bounds-check the whole section BEFORE allocating anything sized by nlist/n. 
+ const ivfBody = nlist * dim * FLOAT_BYTES + n * U32_BYTES; + if (off + ivfBody > bytes.length) { + throw new DeserializeError('BAD_IVF', 'ivf section exceeds buffer'); + } + const centroids = new Float32Array(nlist * dim); + for (let i = 0; i < centroids.length; i++) { + const x = dv.getFloat32(off, true); + if (!Number.isFinite(x)) { + throw new DeserializeError('BAD_IVF', `centroid coordinate ${i} is not finite`); + } + centroids[i] = x; + off += FLOAT_BYTES; + } + const listForSlot = new Int32Array(n); + for (let i = 0; i < n; i++) { + const l = dv.getUint32(off, true); + if (l >= nlist) { + throw new DeserializeError( + 'BAD_IVF', + `listForSlot[${i}] = ${l} out of range [0, ${nlist})`, + ); + } + listForSlot[i] = l; + off += U32_BYTES; + } + ivf = { nlist, nprobe, centroids, listForSlot }; + } else if (ivfFlag !== 0) { + throw new DeserializeError('BAD_IVF', `invalid ivf flag ${ivfFlag}`); + } + const base: IndexPayload = { metric, bits: bits as Bits, dim, n, seed, codes, scales, norms }; if (calibration !== undefined) base.calibration = calibration; + if (ivf !== undefined) base.ivf = ivf; if (kindByte === 0) { if (off !== bytes.length) { throw new DeserializeError('BAD_LENGTH', `${bytes.length - off} trailing bytes after body`);