diff --git a/pyproject.toml b/pyproject.toml index 9c9a4b3..8124b29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rationai-mlkit" -version = "0.4.0" +version = "0.4.1" description = "" authors = [ { name = "Matěj Pekár", email = "matejpekar@mail.muni.cz" }, diff --git a/rationai/mlkit/data/datasets/meta_tiled_slides.py b/rationai/mlkit/data/datasets/meta_tiled_slides.py index c65a6b0..8437e41 100644 --- a/rationai/mlkit/data/datasets/meta_tiled_slides.py +++ b/rationai/mlkit/data/datasets/meta_tiled_slides.py @@ -2,9 +2,10 @@ from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import TypeVar, cast +from typing import Any, TypeVar -import pyarrow.compute as pc +import numpy as np +import pyarrow as pa from datasets import Dataset as HFDataset from datasets import concatenate_datasets, load_dataset from mlflow.artifacts import download_artifacts @@ -31,71 +32,88 @@ def __init__( paths: Iterable[Path | str] | None = None, uris: Iterable[str] | None = None, slides_and_tiles: tuple[HFDataset, HFDataset] | None = None, + hf_kwargs: dict[str, Any] | None = None, ) -> None: """Load slides and tiles from MLFlow artifacts. Args: paths: List of directories to load slides and tiles from. Each - directory must include two files: `slides.parquet` and tiles.parquet`. + directory must include either single files (`slides.parquet` + and `tiles.parquet`) or subdirectories (`slides/` and `tiles/`) + containing chunked Parquet files. uris: List of MLFlow artifact URIs pointing to folders containing - `slides.parquet` and `tiles.parquet`. + either single files (`slides.parquet` and `tiles.parquet`) or + subdirectories (`slides/` and `tiles/`) containing chunked + Parquet files. slides_and_tiles: Tuple containing the slides and tiles Datasets. + hf_kwargs: Additional keyword arguments to pass to HuggingFace's + `load_dataset` function. Defaults to `{"path": "parquet", "split": "train"}`. """ assert paths or uris or slides_and_tiles, ( "At least one of paths, uris or slides_and_tiles must be provided." ) - slides, tiles = self.load_slides_and_tiles(paths or [], uris or []) + if hf_kwargs is None: + hf_kwargs = {"path": "parquet", "split": "train"} + + slides, tiles = self.load_slides_and_tiles(paths or [], uris or [], hf_kwargs) if slides_and_tiles is not None: slides = concatenate_datasets([slides, slides_and_tiles[0]]) tiles = concatenate_datasets([tiles, slides_and_tiles[1]]) self.slides = slides - self.tiles = tiles.sort("slide_id") + self.tiles = tiles self._slide_id_to_indices = self._build_tile_index(self.tiles) super().__init__(self.generate_datasets()) @staticmethod - def _build_tile_index(tiles: HFDataset) -> dict[str, range]: + def _build_tile_index(tiles: HFDataset) -> dict[str | bytes, pa.ListScalar]: """Creates a fast lookup table for slide indices. - This function builds a mapping from `slide_id` to the range of indices in the - `tiles` dataset that correspond to that slide. It assumes that the `tiles` dataset - is sorted by `slide_id`, which allows for efficient retrieval of tile indices - for each slide without needing to scan the entire dataset for each slide. + This function builds a mapping from `slide_id` to the list of indices in the + `tiles` dataset that correspond to that slide. Args: - tiles: A dataset containing a `slide_id` column, sorted by `slide_id`. + tiles: A dataset containing a `slide_id` column. Returns: - A dictionary mapping each `slide_id` to a range of indices in the `tiles` dataset. + A dictionary mapping each `slide_id` to a list of indices in the `tiles` dataset. """ if len(tiles) == 0: return {} - # Get the underlying Arrow table (zero-copy) - table = tiles.data.table - # Since it's sorted, we only care about where 'slide_id' changes. - slide_ids = table.column("slide_id") + # 1. Grab the column directly from the underlying PyArrow Table + slide_ids = tiles.data.column("slide_id") + num_rows = len(slide_ids) + + # 2. Handle the "Large" type conversion + current_type = slide_ids.type + if pa.types.is_string(current_type): + slide_ids = slide_ids.cast(pa.large_string()) + elif pa.types.is_binary(current_type): + slide_ids = slide_ids.cast(pa.large_binary()) - # Since the dataset is sorted by 'slide_id', we can use - # run-end encoding to find group boundaries efficiently. - run_ends = pc.run_end_encode(slide_ids.combine_chunks()) # pyright: ignore[reportAttributeAccessIssue] + # 3. Generate sequential row indices + # np.arange is used here because PyArrow can wrap it instantly with zero-copy overhead + row_indices = pa.array(np.arange(num_rows, dtype=np.int64)) - values = run_ends.values - ends = run_ends.run_ends + # 4. Combine them into a lightweight PyArrow Table + table = pa.Table.from_arrays( + [slide_ids, row_indices], names=["slide_id", "idx"] + ) - index_map = {} - current_offset = 0 + # 5. Perform the native Arrow groupby and aggregate + # The "list" function aggregates all indices for a given slide_id into a single Arrow List scalar + grouped = table.group_by("slide_id").aggregate([("idx", "list")]) - for sid, end in zip(values, ends, strict=True): - end_py = end.as_py() - index_map[sid.as_py()] = range(current_offset, end_py) - current_offset = end_py + # 6. Extract keys to Python, but KEEP values as PyArrow ListScalars + keys = grouped.column("slide_id").to_numpy() + values_array = grouped.column("idx_list") - return index_map + # Map the string key to the PyArrow ListScalar + return {key: values_array[i] for i, key in enumerate(keys)} @abstractmethod def generate_datasets(self) -> Iterable[Dataset[T]]: @@ -116,7 +134,7 @@ def generate_datasets(self) -> Iterable[Dataset[T]]: ``` """ - def filter_tiles_by_slide(self, slide_id: str) -> HFDataset: + def filter_tiles_by_slide(self, slide_id: str | bytes) -> HFDataset: """Returns a view of the dataset using a slice or indices. This function creates a view of the `self.tiles` dataset that contains only @@ -130,20 +148,28 @@ def filter_tiles_by_slide(self, slide_id: str) -> HFDataset: Returns: A view of the tiles dataset containing only the tiles for the specified slide. """ - tile_range = self._slide_id_to_indices.get(slide_id, range(0)) - return self.tiles.select(tile_range) + tile_indices = self._slide_id_to_indices.get( + slide_id, pa.scalar([], type=pa.list_(pa.int64())) + ) + return self.tiles.select(tile_indices.values.to_numpy()) @staticmethod def load_slides_and_tiles( - paths: Iterable[str | Path], uris: Iterable[str] + paths: Iterable[str | Path], uris: Iterable[str], hf_kwargs: dict[str, Any] ) -> tuple[HFDataset, HFDataset]: """Load slides and tiles parquets from local storage and MLFlow artifacts. Args: paths: List of directories to load slides and tiles from. Each - directory must include two files: `slides.parquet` and tiles.parquet`. + directory must include either single files (`slides.parquet` + and `tiles.parquet`) or subdirectories (`slides/` and `tiles/`) + containing chunked Parquet files. uris: List of MLFlow artifact URIs pointing to folders containing - `slides.parquet` and `tiles.parquet`. + either single files (`slides.parquet` and `tiles.parquet`) or + subdirectories (`slides/` and `tiles/`) containing chunked + Parquet files. + hf_kwargs: Additional keyword arguments to pass to HuggingFace's + `load_dataset` function. Raises: FileNotFoundError: If the data cannot be loaded from the specified URIs. @@ -159,27 +185,35 @@ def load_slides_and_tiles( search_dirs = [Path(p) for p in (*paths, *artifacts_paths)] - # Extract existing file paths - slide_files = [ - str(s) for p in search_dirs if (s := p / "slides.parquet").exists() - ] - tile_files = [ - str(t) for p in search_dirs if (t := p / "tiles.parquet").exists() - ] - # Handle empty datasets - if not (slide_files and tile_files): + if not len(search_dirs): return HFDataset.from_dict({}), HFDataset.from_dict({}) + def resolve_search_path(partition: str) -> list[dict[str, str]]: + return [ + {"data_dir": str(path / partition)} + if (path / partition).is_dir() + else {"data_files": str(path / f"{partition}.parquet")} + for path in search_dirs + ] + try: - # Load datasets with memory mapping (lazy) - loader_kwargs = {"path": "parquet", "split": "train"} + slides_ds = concatenate_datasets( + [ + load_dataset(**hf_kwargs, **datasource) + for datasource in resolve_search_path("slides") + ] + ) - slides_ds = load_dataset(**loader_kwargs, data_files=slide_files) # pyright: ignore[reportArgumentType, reportCallIssue] - tiles_ds = load_dataset(**loader_kwargs, data_files=tile_files) # pyright: ignore[reportArgumentType, reportCallIssue] + tiles_ds = concatenate_datasets( + [ + load_dataset(**hf_kwargs, **datasource) + for datasource in resolve_search_path("tiles") + ] + ) - return cast("HFDataset", slides_ds), cast("HFDataset", tiles_ds) + return slides_ds, tiles_ds except Exception as e: - msg = f"Failed to load Parquet files. Found {len(slide_files)} slides and {len(tile_files)} tiles." + msg = "Failed to load Parquet files." raise FileNotFoundError(msg) from e