From df817001928f57ff7738527f58fa651e359ec238 Mon Sep 17 00:00:00 2001
From: Arun Sharma
Date: Mon, 27 Apr 2026 17:17:50 -0700
Subject: [PATCH 1/7] Add dedicated Arrow CSR result type

---
 src_cpp/include/py_connection.h   |  2 +
 src_cpp/include/py_query_result.h |  1 +
 src_cpp/py_connection.cpp        | 10 +++++
 src_cpp/py_query_result.cpp      | 55 ++++++++++++++++++++++++
 src_py/__init__.py               |  4 +-
 src_py/_lbug_capi.py             |  3 ++
 src_py/connection.py             | 23 +++++++++-
 src_py/query_result.py           | 52 +++++++++++++++++++++++
 test/test_arrow.py               | 70 +++++++++++++++++++++++++++++++
 9 files changed, 218 insertions(+), 2 deletions(-)

diff --git a/src_cpp/include/py_connection.h b/src_cpp/include/py_connection.h
index b05a087..2817f87 100644
--- a/src_cpp/include/py_connection.h
+++ b/src_cpp/include/py_connection.h
@@ -29,6 +29,8 @@ class PyConnection {
         const py::dict& params);
 
     std::unique_ptr<PyQueryResult> query(const std::string& statement);
+    std::unique_ptr<PyQueryResult> queryAsArrow(const std::string& statement,
+        int64_t chunkSize);
 
     void setMaxNumThreadForExec(uint64_t numThreads);
 
diff --git a/src_cpp/include/py_query_result.h b/src_cpp/include/py_query_result.h
index 4243bdf..dfec9ab 100644
--- a/src_cpp/include/py_query_result.h
+++ b/src_cpp/include/py_query_result.h
@@ -35,6 +35,7 @@ class PyQueryResult {
    py::object getAsDF();
 
    lbug::pyarrow::Table getAsArrow(std::int64_t chunkSize, bool fallbackExtensionTypes);
+   py::dict getCSR();
 
    py::list getColumnDataTypes();
 
diff --git a/src_cpp/py_connection.cpp b/src_cpp/py_connection.cpp
index f8a7139..1abf5ea 100644
--- a/src_cpp/py_connection.cpp
+++ b/src_cpp/py_connection.cpp
@@ -31,6 +31,8 @@ void PyConnection::initialize(py::handle& m) {
         .def("execute", &PyConnection::execute, py::arg("prepared_statement"),
             py::arg("parameters") = py::dict())
         .def("query", &PyConnection::query, py::arg("statement"))
+        .def("query_as_arrow", &PyConnection::queryAsArrow, py::arg("statement"),
+            py::arg("chunk_size"))
         .def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec,
             py::arg("num_threads"))
         .def("prepare", &PyConnection::prepare, py::arg("query"),
@@ -175,6 +177,14 @@ std::unique_ptr<PyQueryResult> PyConnection::query(const std::string& statement) {
     return checkAndWrapQueryResult(queryResult);
 }
 
+std::unique_ptr<PyQueryResult> PyConnection::queryAsArrow(const std::string& statement,
+    int64_t chunkSize) {
+    py::gil_scoped_release release;
+    auto queryResult = conn->queryAsArrow(statement, chunkSize);
+    py::gil_scoped_acquire acquire;
+    return checkAndWrapQueryResult(queryResult);
+}
+
 void PyConnection::setMaxNumThreadForExec(uint64_t numThreads) {
     conn->setMaxNumThreadForExec(numThreads);
 }
diff --git a/src_cpp/py_query_result.cpp b/src_cpp/py_query_result.cpp
index b3e0db8..01d65de 100644
--- a/src_cpp/py_query_result.cpp
+++ b/src_cpp/py_query_result.cpp
@@ -7,12 +7,14 @@
 #include "common/arrow/arrow_row_batch.h"
 #include "common/constants.h"
 #include "common/exception/not_implemented.h"
+#include "common/exception/runtime.h"
 #include "common/types/uuid.h"
 #include "common/types/value/nested.h"
 #include "common/types/value/node.h"
 #include "common/types/value/rel.h"
 #include "datetime.h" // python lib
 #include "include/py_query_result_converter.h"
+#include "main/query_result/arrow_query_result.h"
 
 using namespace lbug::common;
 using lbug::importCache;
@@ -30,6 +32,7 @@ void PyQueryResult::initialize(py::handle& m) {
         .def("close", &PyQueryResult::close)
         .def("getAsDF", &PyQueryResult::getAsDF)
         .def("getAsArrow", &PyQueryResult::getAsArrow)
+        .def("getCSR", &PyQueryResult::getCSR)
         .def("getColumnNames", &PyQueryResult::getColumnNames)
         .def("getColumnDataTypes", &PyQueryResult::getColumnDataTypes)
         .def("resetIterator", &PyQueryResult::resetIterator)
@@ -85,6 +88,30 @@ void PyQueryResult::close() {
     }
 }
 
+namespace {
+
+py::array_t<int64_t> copyToNumpyArray(const std::vector<int64_t>& values) {
+    auto result = py::array_t<int64_t>(values.size());
+    auto* data = static_cast<int64_t*>(result.request().ptr);
+    std::copy(values.begin(), values.end(), data);
+    return result;
+}
+
+py::dict buildCSRResult(std::vector<int64_t> indptr, std::vector<int64_t> indices,
+    std::vector<int64_t> edgeIDs, bool includeEdgeIDs) {
+    py::dict result;
+    result["indptr"] = copyToNumpyArray(indptr);
+    result["indices"] = copyToNumpyArray(indices);
+    if (includeEdgeIDs) {
+        result["edge_ids"] = copyToNumpyArray(edgeIDs);
+    } else {
+        result["edge_ids"] = py::none();
+    }
+    return result;
+}
+
+} // namespace
+
 static py::object converTimestampToPyObject(timestamp_t& timestamp) {
     int32_t year = 0, month = 0, day = 0, hour = 0, min = 0, sec = 0, micros = 0;
     date_t date;
@@ -320,6 +347,23 @@ py::object PyQueryResult::getArrowChunks(const std::vector<LogicalType>& types,
 
 lbug::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize,
     bool fallbackExtensionTypes) {
+    if (queryResult->getType() == QueryResultType::ARROW) {
+        auto types = queryResult->getColumnDataTypes();
+        auto names = queryResult->getColumnNames();
+        py::list batches;
+        auto batchImportFunc = importCache->pyarrow.lib.RecordBatch._import_from_c();
+        while (queryResult->hasNextArrowChunk()) {
+            auto data = queryResult->getNextArrowChunk(chunkSize);
+            auto schema = ArrowConverter::toArrowSchema(types, names, fallbackExtensionTypes);
+            batches.append(
+                batchImportFunc((std::uint64_t)data.get(), (std::uint64_t)schema.get()));
+        }
+        auto schema = ArrowConverter::toArrowSchema(types, names, fallbackExtensionTypes);
+        auto fromBatchesFunc = importCache->pyarrow.lib.Table.from_batches();
+        auto schemaImportFunc = importCache->pyarrow.lib.Schema._import_from_c();
+        auto schemaObj = schemaImportFunc((std::uint64_t)schema.get());
+        return py::cast(fromBatchesFunc(batches, schemaObj));
+    }
     auto types = queryResult->getColumnDataTypes();
     auto names = queryResult->getColumnNames();
     py::list batches = getArrowChunks(types, names, chunkSize, fallbackExtensionTypes);
@@ -330,6 +374,17 @@ lbug::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize,
     return py::cast(fromBatchesFunc(batches, schemaObj));
 }
 
+py::dict PyQueryResult::getCSR() {
+    if (auto* arrowQueryResult = dynamic_cast<lbug::main::ArrowQueryResult*>(queryResult);
+        arrowQueryResult != nullptr && arrowQueryResult->hasCSRMetadata()) {
+        const auto& metadata = arrowQueryResult->getCSRMetadata();
+        return buildCSRResult(metadata.indptr, metadata.indices, metadata.edgeIDs,
+            metadata.hasEdgeIDs);
+    }
+    throw RuntimeException(
+        "CSR export is only supported for Arrow query results with native CSR metadata.");
+}
+
 py::list PyQueryResult::getColumnDataTypes() {
     auto columnDataTypes = queryResult->getColumnDataTypes();
     py::tuple result(columnDataTypes.size());
diff --git a/src_py/__init__.py b/src_py/__init__.py
index 6a60db1..782cbda 100644
--- a/src_py/__init__.py
+++ b/src_py/__init__.py
@@ -56,7 +56,7 @@
 from .connection import Connection  # noqa: E402
 from .database import Database  # noqa: E402
 from .prepared_statement import PreparedStatement  # noqa: E402
-from .query_result import QueryResult  # noqa: E402
+from .query_result import ArrowQueryResult, CSRResult, QueryResult  # noqa: E402
 from .types import Type  # noqa: E402
 
 _VERSION_INFO: tuple[str, int] | None = None
@@ -80,7 +80,9 @@ def __getattr__(name: str) -> str | int:
 
 __all__ = [
     "AsyncConnection",
+    "ArrowQueryResult",
     "Connection",
+    "CSRResult",
     "Database",
     "PreparedStatement",
     "QueryResult",
diff --git a/src_py/_lbug_capi.py b/src_py/_lbug_capi.py
index 75e4f80..423dd04 100644
--- a/src_py/_lbug_capi.py
+++ b/src_py/_lbug_capi.py
@@ -1229,6 +1229,9 @@ def getAsArrow(self, *_args: Any, **_kwargs: Any) -> Any:
             "Arrow export is not yet implemented in C-API backend"
         )
 
+    def getCSR(self, *_args: Any, **_kwargs: Any) -> Any:
+        raise NotImplementedError("CSR export is not yet implemented in C-API backend")
+
     def getAsDF(self) -> Any:
         raise NotImplementedError(
             "DataFrame export is not yet implemented in C-API backend"
diff --git a/src_py/connection.py b/src_py/connection.py
index 6f47b3a..0fad143 100644
--- a/src_py/connection.py
+++ b/src_py/connection.py
@@ -8,7 +8,7 @@
 
 from ._backend import get_capi_module, get_pybind_module
 from .prepared_statement import PreparedStatement
-from .query_result import QueryResult
+from .query_result import ArrowQueryResult, QueryResult
 
 if TYPE_CHECKING:
     import sys
@@ -369,6 +369,27 @@ def execute(
                 all_query_results.append(next_query_result)
             return all_query_results
 
+    def query_as_arrow(self, query: str, chunk_size: int) -> ArrowQueryResult:
+        """
+        Execute a query with the native Arrow collector path.
+
+        This is the efficient path for CSR-aware Arrow export.
+        """
+        self.init_connection()
+        if not self._using_pybind_backend():
+            msg = "query_as_arrow requires the pybind backend"
+            raise NotImplementedError(msg)
+        query_result_internal = self._get_pybind_connection().query_as_arrow(
+            query, chunk_size
+        )
+        if not query_result_internal.isSuccess():
+            raise RuntimeError(query_result_internal.getErrorMessage())
+        current_query_result = ArrowQueryResult(
+            self, query_result_internal, native_chunk_size=chunk_size
+        )
+        self._register_query_result(current_query_result)
+        return current_query_result
+
     def _prepare(
         self,
         query: str,
diff --git a/src_py/query_result.py b/src_py/query_result.py
index 12cd8d6..e9e454a 100644
--- a/src_py/query_result.py
+++ b/src_py/query_result.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
 from .constants import DST, ID, LABEL, NODES, RELS, SRC
@@ -525,6 +526,57 @@ def rows_as_dict(self, state=True) -> Self:
         return self
 
 
+class ArrowQueryResult(QueryResult):
+    """QueryResult backed by the native Arrow collector path."""
+
+    def __init__(
+        self, connection: Any, query_result: Any, native_chunk_size: int
+    ) -> None:
+        super().__init__(connection, query_result)
+        self._native_chunk_size = native_chunk_size
+
+    def get_as_arrow(
+        self, chunk_size: int | None = None, *, fallbackExtensionTypes: bool = False
+    ) -> pa.Table:
+        """
+        Get the query result as a PyArrow Table.
+
+        Arrow-native results preserve the execution-time chunking chosen by
+        `Connection.query_as_arrow(...)`. Passing `None` or any non-positive
+        chunk size reuses that native chunk size instead of rechunking.
+        """
+        if chunk_size is None or chunk_size <= 0:
+            chunk_size = self._native_chunk_size
+        return super().get_as_arrow(
+            chunk_size, fallbackExtensionTypes=fallbackExtensionTypes
+        )
+
+    def csr(self) -> CSRResult:
+        """
+        Get native CSR arrays from an Arrow query result.
+
+        This is available only for Arrow results with CSR metadata, typically
+        from `Connection.query_as_arrow(...)` on relationship-shaped projections.
+ """ + self.check_for_query_result_close() + + import pyarrow as pa + + csr = self._query_result.getCSR() + return CSRResult( + indptr=pa.array(csr["indptr"]), + indices=pa.array(csr["indices"]), + edge_ids=(None if csr["edge_ids"] is None else pa.array(csr["edge_ids"])), + ) + + +@dataclass(frozen=True) +class CSRResult: + indptr: pa.Array + indices: pa.Array + edge_ids: pa.Array | None = None + + def _row_to_dict(columns: list[str], row: list[Any]) -> dict[str, Any]: if len(columns) != len(row): msg = "Number of columns in output row does not match number of columns" diff --git a/test/test_arrow.py b/test/test_arrow.py index 72c7af2..3eb784e 100644 --- a/test/test_arrow.py +++ b/test/test_arrow.py @@ -772,3 +772,73 @@ def test_to_arrow1(conn: lb.Connection) -> None: -1 ) # what is a chunk size of -1 even supposed to mean? assert arrow_tbl == [] + + +def test_query_as_arrow_csr_with_rel_ids(conn_db_readonly: ConnDB) -> None: + conn, _ = conn_db_readonly + query = """ + MATCH (a:person)-[b:knows]->(c:person) + RETURN a.rowid, b.rowid, c.rowid + """ + rows = conn.execute(query).get_all() + csr = conn.query_as_arrow(query, 8).csr() + + assert csr.edge_ids is not None + + reconstructed = [] + indptr = csr.indptr.to_pylist() + indices = csr.indices.to_pylist() + edge_ids = csr.edge_ids.to_pylist() + for src_rowid in range(len(indptr) - 1): + for idx in range(indptr[src_rowid], indptr[src_rowid + 1]): + reconstructed.append([src_rowid, edge_ids[idx], indices[idx]]) + + assert reconstructed == rows + + +def test_query_as_arrow_csr_with_extra_columns(conn_db_readonly: ConnDB) -> None: + conn, _ = conn_db_readonly + query = """ + MATCH (a:person)-[b:knows]->(c:person) + RETURN a.rowid, b.rowid, c.rowid, b.date, c.fName + """ + result = conn.query_as_arrow(query, 8) + csr = result.csr() + arrow_tbl = result.get_as_arrow(0) + + assert csr.edge_ids is not None + assert arrow_tbl.column_names == [ + "a.rowid", + "b.rowid", + "c.rowid", + "b.date", + "c.fName", + ] + assert len(csr.indptr) >= 2 + + +def test_query_as_arrow_csr_without_rel_ids(conn_db_readonly: ConnDB) -> None: + conn, _ = conn_db_readonly + query = """ + MATCH (a:person)-[:knows]->(c:person) + RETURN a.rowid, c.rowid + """ + rows = conn.execute(query).get_all() + csr = conn.query_as_arrow(query, 8).csr() + + assert csr.edge_ids is None + + reconstructed = [] + indptr = csr.indptr.to_pylist() + indices = csr.indices.to_pylist() + for src_rowid in range(len(indptr) - 1): + for idx in range(indptr[src_rowid], indptr[src_rowid + 1]): + reconstructed.append([src_rowid, indices[idx]]) + + assert reconstructed == rows + + +def test_query_as_arrow_csr_rejects_non_csr_shape(conn_db_readonly: ConnDB) -> None: + conn, _ = conn_db_readonly + with pytest.raises(RuntimeError, match="CSR export is only supported"): + conn.query_as_arrow("MATCH (a:person) RETURN a.fName", 8).csr() From 3013a21b49129f2fa93d9a0c96d67f49ac1a2cac Mon Sep 17 00:00:00 2001 From: Arun Sharma Date: Wed, 6 May 2026 13:25:49 -0700 Subject: [PATCH 2/7] Fix pre-commit --- .github/workflows/ci.yml | 4 ++-- src_py/query_result.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 317952d..10d0cb2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -113,7 +113,7 @@ jobs: - name: Check formatting (black) working-directory: ladybug/tools/python_api run: | - uv pip install black + uv pip install black==26.3.0 .venv/bin/black --check src_py test - name: Run ruff check @@ -177,7 +177,7 @@ 
@@ -177,7 +177,7 @@ jobs:
       - name: Check formatting (black)
         working-directory: ladybug/tools/python_api
         run: |
-          uv pip install black
+          uv pip install black==26.3.0
           .venv/bin/black --check src_py test
 
       - name: Run ruff check
diff --git a/src_py/query_result.py b/src_py/query_result.py
index e9e454a..93c0b65 100644
--- a/src_py/query_result.py
+++ b/src_py/query_result.py
@@ -572,6 +572,8 @@ def csr(self) -> CSRResult:
 
 @dataclass(frozen=True)
 class CSRResult:
+    """Native CSR arrays returned by an Arrow query result."""
+
     indptr: pa.Array
     indices: pa.Array
     edge_ids: pa.Array | None = None

From 9cadde564a930ac5b3c49e778a94911c5eec5d68 Mon Sep 17 00:00:00 2001
From: Arun Sharma
Date: Wed, 6 May 2026 13:25:56 -0700
Subject: [PATCH 3/7] Add pre-commit

---
 .pre-commit-config.yaml | 10 ++++++++++
 1 file changed, 10 insertions(+)
 create mode 100644 .pre-commit-config.yaml

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..ea0434f
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,10 @@
+repos:
+  - repo: https://github.com/psf/black
+    rev: 26.3.0
+    hooks:
+      - id: black
+
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.11.12
+    hooks:
+      - id: ruff-check

From c0c4ab74a1603cfbf5ba32b0279fbf15322cc38d Mon Sep 17 00:00:00 2001
From: Arun Sharma
Date: Wed, 6 May 2026 13:50:15 -0700
Subject: [PATCH 4/7] Skip C API CSR Arrow tests

---
 test/capi_xfails.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/test/capi_xfails.py b/test/capi_xfails.py
index e963585..036d976 100644
--- a/test/capi_xfails.py
+++ b/test/capi_xfails.py
@@ -6,6 +6,10 @@
     "test/test_arrow.py::test_to_arrow_map",
    "test/test_arrow.py::test_to_arrow_array",
    "test/test_arrow.py::test_to_arrow_complex",
+    "test/test_arrow.py::test_query_as_arrow_csr_with_rel_ids",
+    "test/test_arrow.py::test_query_as_arrow_csr_with_extra_columns",
+    "test/test_arrow.py::test_query_as_arrow_csr_without_rel_ids",
+    "test/test_arrow.py::test_query_as_arrow_csr_rejects_non_csr_shape",
    "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_basic",
    "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_filtering",
    "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_with_pandas",

From 5dbfff79213e310ec30fbf4535ae9c8e71895e37 Mon Sep 17 00:00:00 2001
From: Arun Sharma
Date: Wed, 6 May 2026 13:57:27 -0700
Subject: [PATCH 5/7] Statically link JSON extension in pybind CI

---
 .github/workflows/ci.yml | 3 ++-
 test/test_json.py        | 2 --
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 10d0cb2..3b34cd5 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -140,7 +140,7 @@ jobs:
 
       - name: Update submodules
         working-directory: ladybug
-        run: git submodule update --init --recursive dataset
+        run: git submodule update --init --recursive dataset extension
 
       - name: Checkout ladybug-python into ladybug/tools/python_api
         uses: actions/checkout@v4
@@ -191,6 +191,7 @@ jobs:
          GEN: Ninja
          CMAKE_C_COMPILER_LAUNCHER: ccache
          CMAKE_CXX_COMPILER_LAUNCHER: ccache
+         EXTRA_CMAKE_FLAGS: -DBUILD_EXTENSIONS=json -DEXTENSION_STATIC_LINK_LIST=json
        run: |
          make python
          cp tools/python_api/src_py/*.py tools/python_api/build/ladybug/
diff --git a/test/test_json.py b/test/test_json.py
index f10e8b1..2e57175 100644
--- a/test/test_json.py
+++ b/test/test_json.py
@@ -11,8 +11,6 @@ def test_to_json_string_param_roundtrip(conn_db_empty: ConnDB) -> None:
     """to_json() with a JSON string parameter should store the parsed object,
not a string literal.""" conn, _ = conn_db_empty conn.execute(""" - INSTALL json; - LOAD json; CREATE NODE TABLE User (id SERIAL PRIMARY KEY, meta JSON); """) From a655e4d0db6c0aee2a711c430fbe305b974ca7c8 Mon Sep 17 00:00:00 2001 From: Arun Sharma Date: Wed, 6 May 2026 16:11:07 -0700 Subject: [PATCH 6/7] Return CSR metadata as PyArrow arrays --- .../include/cached_import/py_cached_modules.h | 12 +++++++++- src_cpp/py_query_result.cpp | 23 ++++++++----------- src_py/query_result.py | 8 +++---- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/src_cpp/include/cached_import/py_cached_modules.h b/src_cpp/include/cached_import/py_cached_modules.h index f595692..dad381a 100644 --- a/src_cpp/include/cached_import/py_cached_modules.h +++ b/src_cpp/include/cached_import/py_cached_modules.h @@ -103,6 +103,14 @@ class PolarsCachedItem : public PythonCachedItem { }; class PyarrowCachedItem : public PythonCachedItem { + class ArrayCachedItem : public PythonCachedItem { + public: + explicit ArrayCachedItem(PythonCachedItem* parent) + : PythonCachedItem("Array", parent), _import_from_c("_import_from_c", this) {} + + PythonCachedItem _import_from_c; + }; + class RecordBatchCachedItem : public PythonCachedItem { public: explicit RecordBatchCachedItem(PythonCachedItem* parent) @@ -132,8 +140,10 @@ class PyarrowCachedItem : public PythonCachedItem { class LibCachedItem : public PythonCachedItem { public: explicit LibCachedItem(PythonCachedItem* parent) - : PythonCachedItem("lib", parent), RecordBatch(this), Schema(this), Table(this) {} + : PythonCachedItem("lib", parent), Array(this), RecordBatch(this), Schema(this), + Table(this) {} + ArrayCachedItem Array; RecordBatchCachedItem RecordBatch; SchemaCachedItem Schema; TableCachedItem Table; diff --git a/src_cpp/py_query_result.cpp b/src_cpp/py_query_result.cpp index 01d65de..fed2ae7 100644 --- a/src_cpp/py_query_result.cpp +++ b/src_cpp/py_query_result.cpp @@ -90,20 +90,17 @@ void PyQueryResult::close() { namespace { -py::array_t copyToNumpyArray(const std::vector& values) { - auto result = py::array_t(values.size()); - auto* data = static_cast(result.request().ptr); - std::copy(values.begin(), values.end(), data); - return result; +py::object importCSRArrowArray(lbug::main::ArrowQueryResult::CSRArrowArray& array) { + auto arrayImportFunc = importCache->pyarrow.lib.Array._import_from_c(); + return arrayImportFunc((std::uint64_t)&array.array, (std::uint64_t)&array.schema); } -py::dict buildCSRResult(std::vector indptr, std::vector indices, - std::vector edgeIDs, bool includeEdgeIDs) { +py::dict buildCSRResult(lbug::main::ArrowQueryResult::CSRArrowArrays arrays) { py::dict result; - result["indptr"] = copyToNumpyArray(indptr); - result["indices"] = copyToNumpyArray(indices); - if (includeEdgeIDs) { - result["edge_ids"] = copyToNumpyArray(edgeIDs); + result["indptr"] = importCSRArrowArray(arrays.indptr); + result["indices"] = importCSRArrowArray(arrays.indices); + if (arrays.edgeIDs.has_value()) { + result["edge_ids"] = importCSRArrowArray(*arrays.edgeIDs); } else { result["edge_ids"] = py::none(); } @@ -377,9 +374,7 @@ lbug::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize, py::dict PyQueryResult::getCSR() { if (auto* arrowQueryResult = dynamic_cast(queryResult); arrowQueryResult != nullptr && arrowQueryResult->hasCSRMetadata()) { - const auto& metadata = arrowQueryResult->getCSRMetadata(); - return buildCSRResult(metadata.indptr, metadata.indices, metadata.edgeIDs, - metadata.hasEdgeIDs); + return 
+        return buildCSRResult(arrowQueryResult->getCSRArrowArrays());
     }
     throw RuntimeException(
         "CSR export is only supported for Arrow query results with native CSR metadata.");
diff --git a/src_py/query_result.py b/src_py/query_result.py
index 93c0b65..ba8d31e 100644
--- a/src_py/query_result.py
+++ b/src_py/query_result.py
@@ -560,13 +560,11 @@ def csr(self) -> CSRResult:
         """
         self.check_for_query_result_close()
 
-        import pyarrow as pa
-
         csr = self._query_result.getCSR()
         return CSRResult(
-            indptr=pa.array(csr["indptr"]),
-            indices=pa.array(csr["indices"]),
-            edge_ids=(None if csr["edge_ids"] is None else pa.array(csr["edge_ids"])),
+            indptr=csr["indptr"],
+            indices=csr["indices"],
+            edge_ids=csr["edge_ids"],
         )
 
 

From 1dc045cab06ba56855976f027f175cb4e0d6b5e8 Mon Sep 17 00:00:00 2001
From: Arun Sharma
Date: Thu, 7 May 2026 12:53:16 -0700
Subject: [PATCH 7/7] Add tests for mixed type arrow tables

---
 test/capi_xfails.py                    |   2 +
 test/test_arrow_memory_backed_table.py | 129 +++++++++++++++++++++++++
 2 files changed, 131 insertions(+)

diff --git a/test/capi_xfails.py b/test/capi_xfails.py
index 036d976..c050e2e 100644
--- a/test/capi_xfails.py
+++ b/test/capi_xfails.py
@@ -16,6 +16,8 @@
     "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_with_pyarrow",
    "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_empty_result",
    "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_count",
+    "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_arrow_node_and_rel_table",
+    "test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_native_node_and_arrow_rel_table",
    "test/test_async_connection.py::test_async_scan_df",
    "test/test_blob_parameter.py::test_bytes_param_udf",
    "test/test_df.py::test_to_df",
diff --git a/test/test_arrow_memory_backed_table.py b/test/test_arrow_memory_backed_table.py
index a3d4df9..abd4680 100644
--- a/test/test_arrow_memory_backed_table.py
+++ b/test/test_arrow_memory_backed_table.py
@@ -273,3 +273,132 @@ def test_arrow_memory_backed_table_count(conn_db_empty: ConnDB) -> None:
 
     # Clean up
     conn.drop_arrow_table("transactions")
+
+
+def test_arrow_memory_backed_arrow_node_and_rel_table(conn_db_empty: ConnDB) -> None:
+    """Test an Arrow memory-backed relationship over Arrow-backed nodes."""
+    conn, _ = conn_db_empty
+
+    pa = pytest.importorskip("pyarrow")
+
+    people = pa.Table.from_arrays(
+        [
+            pa.array([1, 2, 3], type=pa.int64()),
+            pa.array(["Alice", "Bob", "Carol"], type=pa.string()),
+        ],
+        names=["id", "name"],
+    )
+    conn.create_arrow_table("arrow_people", people)
+
+    knows = pa.Table.from_arrays(
+        [
+            pa.array([1, 1, 2], type=pa.int64()),
+            pa.array([2, 3, 3], type=pa.int64()),
+            pa.array([10, 20, 30], type=pa.int64()),
+        ],
+        names=["from", "to", "weight"],
+    )
+    conn.create_arrow_rel_table("arrow_knows", knows, "arrow_people", "arrow_people")
+
+    result = conn.execute(
+        "MATCH (a:arrow_people)-[r:arrow_knows]->(b:arrow_people) "
+        "RETURN a.name, b.name, r.weight ORDER BY a.id, b.id"
+    )
+    rows = []
+    while result.has_next():
+        rows.append(result.get_next())
+
+    assert rows == [
+        ["Alice", "Bob", 10],
+        ["Alice", "Carol", 20],
+        ["Bob", "Carol", 30],
+    ]
+
+    result = conn.execute(
+        "MATCH (:arrow_people)-[r:arrow_knows]->(:arrow_people) "
+        "RETURN COUNT(*), SUM(r.weight)"
+    )
+    assert result.get_next() == [3, 60]
+    assert not result.has_next()
+
+    result = conn.execute(
+        "MATCH (a:arrow_people)-[r:arrow_knows]->(b:arrow_people) "
+        "WHERE r.weight >= 20 "
+        "RETURN a.name, b.name, r.weight ORDER BY r.weight"
+    )
+    rows = []
+    while result.has_next():
+        rows.append(result.get_next())
+
+    assert rows == [
+        ["Alice", "Carol", 20],
+        ["Bob", "Carol", 30],
+    ]
+
+    conn.drop_arrow_table("arrow_knows")
+    conn.drop_arrow_table("arrow_people")
+
+
+def test_arrow_memory_backed_native_node_and_arrow_rel_table(
+    conn_db_empty: ConnDB,
+) -> None:
+    """Test an Arrow memory-backed relationship over native node tables."""
+    conn, _ = conn_db_empty
+
+    pa = pytest.importorskip("pyarrow")
+
+    conn.execute(
+        "CREATE NODE TABLE native_people(id INT64, name STRING, PRIMARY KEY(id));"
+        "CREATE (:native_people {id: 1, name: 'Alice'});"
+        "CREATE (:native_people {id: 2, name: 'Bob'});"
+        "CREATE (:native_people {id: 3, name: 'Carol'});"
+    )
+
+    knows = pa.Table.from_arrays(
+        [
+            pa.array([1, 1, 2], type=pa.int64()),
+            pa.array([2, 3, 3], type=pa.int64()),
+            pa.array([10, 20, 30], type=pa.int64()),
+        ],
+        names=["from", "to", "weight"],
+    )
+    conn.create_arrow_rel_table(
+        "native_people_arrow_knows", knows, "native_people", "native_people"
+    )
+
+    result = conn.execute(
+        "MATCH (a:native_people)-[r:native_people_arrow_knows]->(b:native_people) "
+        "RETURN a.name, b.name, r.weight ORDER BY a.id, b.id"
+    )
+    rows = []
+    while result.has_next():
+        rows.append(result.get_next())
+
+    assert rows == [
+        ["Alice", "Bob", 10],
+        ["Alice", "Carol", 20],
+        ["Bob", "Carol", 30],
+    ]
+
+    result = conn.execute(
+        "MATCH (:native_people)-[r:native_people_arrow_knows]->(:native_people) "
+        "RETURN COUNT(*), SUM(r.weight)"
+    )
+    assert result.get_next() == [3, 60]
+    assert not result.has_next()
+
+    result = conn.execute(
+        "MATCH (a:native_people)-[r:native_people_arrow_knows]->(b:native_people) "
+        "WHERE r.weight >= 20 "
+        "RETURN a.name, b.name, r.weight ORDER BY r.weight"
+    )
+    rows = []
+    while result.has_next():
+        rows.append(result.get_next())
+
+    assert rows == [
+        ["Alice", "Carol", 20],
+        ["Bob", "Carol", 30],
+    ]
+
+    conn.drop_arrow_table("native_people_arrow_knows")
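
A minimal usage sketch of the CSR export this series adds, for anyone trying
the API out. It assumes the `person`/`knows` schema and an open connection
`conn` from the test fixtures above; the numpy/scipy conversion and the int64
dtype of the arrays are illustrative assumptions, not guarantees made by the
patches.

    import numpy as np
    import scipy.sparse as sp

    # Native Arrow collector path; 2048 is the execution-time chunk size,
    # which get_as_arrow(None) will later reuse.
    result = conn.query_as_arrow(
        "MATCH (a:person)-[b:knows]->(c:person) RETURN a.rowid, b.rowid, c.rowid",
        2048,
    )
    csr = result.csr()  # CSRResult of pyarrow arrays (after patch 6)

    # Row i's neighbours are indices[indptr[i]:indptr[i+1]], exactly as the
    # reconstruction loops in test_arrow.py iterate the arrays.
    indptr = np.asarray(csr.indptr)
    indices = np.asarray(csr.indices)
    num_src = len(indptr) - 1
    adjacency = sp.csr_matrix(
        (np.ones(len(indices), dtype=np.int8), indices, indptr),
        shape=(num_src, num_src),  # assumes a square person->person projection
    )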