diff --git a/rationai/mlkit/data/datasets/__init__.py b/rationai/mlkit/data/datasets/__init__.py index 453129e..bf56e5d 100644 --- a/rationai/mlkit/data/datasets/__init__.py +++ b/rationai/mlkit/data/datasets/__init__.py @@ -1,5 +1,6 @@ from rationai.mlkit.data.datasets.meta_tiled_slides import MetaTiledSlides from rationai.mlkit.data.datasets.openslide_tiles_dataset import OpenSlideTilesDataset +from rationai.mlkit.data.datasets.slides_tiles_loader import SlidesTilesLoader -__all__ = ["MetaTiledSlides", "OpenSlideTilesDataset"] +__all__ = ["MetaTiledSlides", "OpenSlideTilesDataset", "SlidesTilesLoader"] diff --git a/rationai/mlkit/data/datasets/meta_tiled_slides.py b/rationai/mlkit/data/datasets/meta_tiled_slides.py index 8437e41..ff61c79 100644 --- a/rationai/mlkit/data/datasets/meta_tiled_slides.py +++ b/rationai/mlkit/data/datasets/meta_tiled_slides.py @@ -1,16 +1,13 @@ from abc import ABC, abstractmethod from collections.abc import Iterable -from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any, TypeVar -import numpy as np -import pyarrow as pa from datasets import Dataset as HFDataset -from datasets import concatenate_datasets, load_dataset -from mlflow.artifacts import download_artifacts from torch.utils.data import ConcatDataset, Dataset +from rationai.mlkit.data.datasets.slides_tiles_loader import SlidesTilesLoader + T = TypeVar("T", covariant=True) @@ -49,71 +46,31 @@ def __init__( hf_kwargs: Additional keyword arguments to pass to HuggingFace's `load_dataset` function. Defaults to `{"path": "parquet", "split": "train"}`. """ - assert paths or uris or slides_and_tiles, ( - "At least one of paths, uris or slides_and_tiles must be provided." + self._meta = SlidesTilesLoader( + paths=paths, + uris=uris, + slides_and_tiles=slides_and_tiles, + hf_kwargs=hf_kwargs, ) - - if hf_kwargs is None: - hf_kwargs = {"path": "parquet", "split": "train"} - - slides, tiles = self.load_slides_and_tiles(paths or [], uris or [], hf_kwargs) - - if slides_and_tiles is not None: - slides = concatenate_datasets([slides, slides_and_tiles[0]]) - tiles = concatenate_datasets([tiles, slides_and_tiles[1]]) - - self.slides = slides - self.tiles = tiles - self._slide_id_to_indices = self._build_tile_index(self.tiles) - + self.slides = self._meta.slides + self.tiles = self._meta.tiles super().__init__(self.generate_datasets()) - @staticmethod - def _build_tile_index(tiles: HFDataset) -> dict[str | bytes, pa.ListScalar]: - """Creates a fast lookup table for slide indices. + def filter_tiles_by_slide(self, slide_id: str | bytes) -> HFDataset: + """Returns a view of the dataset using a slice or indices. - This function builds a mapping from `slide_id` to the list of indices in the - `tiles` dataset that correspond to that slide. + This function creates a view of the `self.tiles` dataset that contains only + the tiles belonging to the specified slide. It uses the precomputed + `_slide_id_to_indices` mapping to efficiently retrieve the relevant tiles + without copying data. Args: - tiles: A dataset containing a `slide_id` column. + slide_id: The ID of the slide to filter tiles. Returns: - A dictionary mapping each `slide_id` to a list of indices in the `tiles` dataset. + A view of the tiles dataset containing only the tiles for the specified slide. """ - if len(tiles) == 0: - return {} - - # 1. Grab the column directly from the underlying PyArrow Table - slide_ids = tiles.data.column("slide_id") - num_rows = len(slide_ids) - - # 2. Handle the "Large" type conversion - current_type = slide_ids.type - if pa.types.is_string(current_type): - slide_ids = slide_ids.cast(pa.large_string()) - elif pa.types.is_binary(current_type): - slide_ids = slide_ids.cast(pa.large_binary()) - - # 3. Generate sequential row indices - # np.arange is used here because PyArrow can wrap it instantly with zero-copy overhead - row_indices = pa.array(np.arange(num_rows, dtype=np.int64)) - - # 4. Combine them into a lightweight PyArrow Table - table = pa.Table.from_arrays( - [slide_ids, row_indices], names=["slide_id", "idx"] - ) - - # 5. Perform the native Arrow groupby and aggregate - # The "list" function aggregates all indices for a given slide_id into a single Arrow List scalar - grouped = table.group_by("slide_id").aggregate([("idx", "list")]) - - # 6. Extract keys to Python, but KEEP values as PyArrow ListScalars - keys = grouped.column("slide_id").to_numpy() - values_array = grouped.column("idx_list") - - # Map the string key to the PyArrow ListScalar - return {key: values_array[i] for i, key in enumerate(keys)} + return self._meta.filter_tiles_by_slide(slide_id) @abstractmethod def generate_datasets(self) -> Iterable[Dataset[T]]: @@ -127,93 +84,9 @@ def generate_datasets(self) -> Iterable[Dataset[T]]: level=slide["level"], tile_extent_x=slide["tile_extent_x"], tile_extent_y=slide["tile_extent_y"], - tiles=self.filter_tiles_by_slide(slide.id), + tiles=self.filter_tiles_by_slide(slide["id"]), ) for slide in self.slides ) ``` """ - - def filter_tiles_by_slide(self, slide_id: str | bytes) -> HFDataset: - """Returns a view of the dataset using a slice or indices. - - This function creates a view of the `self.tiles` dataset that contains only - the tiles belonging to the specified slide. It uses the precomputed - `_slide_id_to_indices` mapping to efficiently retrieve the relevant tiles - without copying data. - - Args: - slide_id: The ID of the slide to filter tiles. - - Returns: - A view of the tiles dataset containing only the tiles for the specified slide. - """ - tile_indices = self._slide_id_to_indices.get( - slide_id, pa.scalar([], type=pa.list_(pa.int64())) - ) - return self.tiles.select(tile_indices.values.to_numpy()) - - @staticmethod - def load_slides_and_tiles( - paths: Iterable[str | Path], uris: Iterable[str], hf_kwargs: dict[str, Any] - ) -> tuple[HFDataset, HFDataset]: - """Load slides and tiles parquets from local storage and MLFlow artifacts. - - Args: - paths: List of directories to load slides and tiles from. Each - directory must include either single files (`slides.parquet` - and `tiles.parquet`) or subdirectories (`slides/` and `tiles/`) - containing chunked Parquet files. - uris: List of MLFlow artifact URIs pointing to folders containing - either single files (`slides.parquet` and `tiles.parquet`) or - subdirectories (`slides/` and `tiles/`) containing chunked - Parquet files. - hf_kwargs: Additional keyword arguments to pass to HuggingFace's - `load_dataset` function. - - Raises: - FileNotFoundError: If the data cannot be loaded from the specified URIs. - - Returns: - A tuple containing the slides and tiles Datasets. - """ - # Parallelize MLFlow downloads (I/O Bound) - with ThreadPoolExecutor() as executor: - artifacts_paths = list( - executor.map(lambda uri: download_artifacts(artifact_uri=uri), uris) - ) - - search_dirs = [Path(p) for p in (*paths, *artifacts_paths)] - - # Handle empty datasets - if not len(search_dirs): - return HFDataset.from_dict({}), HFDataset.from_dict({}) - - def resolve_search_path(partition: str) -> list[dict[str, str]]: - return [ - {"data_dir": str(path / partition)} - if (path / partition).is_dir() - else {"data_files": str(path / f"{partition}.parquet")} - for path in search_dirs - ] - - try: - slides_ds = concatenate_datasets( - [ - load_dataset(**hf_kwargs, **datasource) - for datasource in resolve_search_path("slides") - ] - ) - - tiles_ds = concatenate_datasets( - [ - load_dataset(**hf_kwargs, **datasource) - for datasource in resolve_search_path("tiles") - ] - ) - - return slides_ds, tiles_ds - - except Exception as e: - msg = "Failed to load Parquet files." - raise FileNotFoundError(msg) from e diff --git a/rationai/mlkit/data/datasets/slides_tiles_loader.py b/rationai/mlkit/data/datasets/slides_tiles_loader.py new file mode 100644 index 0000000..7359965 --- /dev/null +++ b/rationai/mlkit/data/datasets/slides_tiles_loader.py @@ -0,0 +1,192 @@ +from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Any + +import numpy as np +import pyarrow as pa +from datasets import Dataset as HFDataset +from datasets import concatenate_datasets, load_dataset +from mlflow.artifacts import download_artifacts + + +class SlidesTilesLoader: + """Loads and concatenates slides/tiles metadata.""" + + def __init__( + self, + *, + paths: Iterable[Path | str] | None = None, + uris: Iterable[str] | None = None, + slides_and_tiles: tuple[HFDataset, HFDataset] | None = None, + hf_kwargs: dict[str, Any] | None = None, + ) -> None: + """Load slides and tiles from local paths, MLFlow URIs, or preloaded datasets. + + Args: + paths: List of directories to load slides and tiles from. Each + directory must include either single files (`slides.parquet` + and `tiles.parquet`) or subdirectories (`slides/` and `tiles/`) + containing chunked Parquet files. + uris: List of MLFlow artifact URIs pointing to folders containing + either single files (`slides.parquet` and `tiles.parquet`) or + subdirectories (`slides/` and `tiles/`) containing chunked + Parquet files. + slides_and_tiles: Tuple containing the slides and tiles Datasets. + hf_kwargs: Additional keyword arguments to pass to HuggingFace's + `load_dataset` function. Defaults to `{"path": "parquet", "split": "train"}`. + """ + if not (paths or uris or slides_and_tiles): + raise ValueError( + "At least one of paths, uris or slides_and_tiles must be provided." + ) + + if hf_kwargs is None: + hf_kwargs = {"path": "parquet", "split": "train"} + + slides = [] + tiles = [] + + if paths or uris: + s, t = self.load_slides_and_tiles(paths or [], uris or [], hf_kwargs) + slides.append(s) + tiles.append(t) + + if slides_and_tiles is not None: + slides.append(slides_and_tiles[0]) + tiles.append(slides_and_tiles[1]) + + self.slides = concatenate_datasets(slides) if len(slides) > 1 else slides[0] + self.tiles = concatenate_datasets(tiles) if len(tiles) > 1 else tiles[0] + self._slide_id_to_indices = self._build_tile_index(self.tiles) + + @staticmethod + def _build_tile_index(tiles: HFDataset) -> dict[str | bytes, pa.ListScalar]: + """Creates a fast lookup table for slide indices. + + This function builds a mapping from `slide_id` to the list of indices in the + `tiles` dataset that correspond to that slide. + + Args: + tiles: A dataset containing a `slide_id` column. + + Returns: + A dictionary mapping each `slide_id` to a list of indices in the `tiles` dataset. + """ + if len(tiles) == 0: + return {} + + # 1. Grab the column directly from the underlying PyArrow Table + slide_ids = tiles.data.column("slide_id") + num_rows = len(slide_ids) + + # 2. Handle the "Large" type conversion + current_type = slide_ids.type + if pa.types.is_string(current_type): + slide_ids = slide_ids.cast(pa.large_string()) + elif pa.types.is_binary(current_type): + slide_ids = slide_ids.cast(pa.large_binary()) + + # 3. Generate sequential row indices + # np.arange is used here because PyArrow can wrap it instantly with zero-copy overhead + row_indices = pa.array(np.arange(num_rows, dtype=np.int64)) + + # 4. Combine them into a lightweight PyArrow Table + table = pa.Table.from_arrays( + [slide_ids, row_indices], names=["slide_id", "idx"] + ) + + # 5. Perform the native Arrow groupby and aggregate + # The "list" function aggregates all indices for a given slide_id into a single Arrow List scalar + grouped = table.group_by("slide_id").aggregate([("idx", "list")]) + + # 6. Extract keys to Python, but KEEP values as PyArrow ListScalars + keys = grouped.column("slide_id").to_numpy() + values_array = grouped.column("idx_list") + + # Map the string key to the PyArrow ListScalar + return {key: values_array[i] for i, key in enumerate(keys)} + + def filter_tiles_by_slide(self, slide_id: str | bytes) -> HFDataset: + """Returns a view of the dataset using a slice or indices. + + This function creates a view of the `self.tiles` dataset that contains only + the tiles belonging to the specified slide. It uses the precomputed + `_slide_id_to_indices` mapping to efficiently retrieve the relevant tiles + without copying data. + + Args: + slide_id: The ID of the slide to filter tiles. + + Returns: + A view of the tiles dataset containing only the tiles for the specified slide. + """ + tile_indices = self._slide_id_to_indices.get( + slide_id, pa.scalar([], type=pa.list_(pa.int64())) + ) + return self.tiles.select(tile_indices.values.to_numpy()) + + @staticmethod + def load_slides_and_tiles( + paths: Iterable[str | Path], uris: Iterable[str], hf_kwargs: dict[str, Any] + ) -> tuple[HFDataset, HFDataset]: + """Load slides and tiles parquets from local storage and MLFlow artifacts. + + Args: + paths: List of directories to load slides and tiles from. Each + directory must include either single files (`slides.parquet` + and `tiles.parquet`) or subdirectories (`slides/` and `tiles/`) + containing chunked Parquet files. + uris: List of MLFlow artifact URIs pointing to folders containing + either single files (`slides.parquet` and `tiles.parquet`) or + subdirectories (`slides/` and `tiles/`) containing chunked + Parquet files. + hf_kwargs: Additional keyword arguments to pass to HuggingFace's + `load_dataset` function. + + Raises: + RuntimeError: If the data cannot be loaded from the specified URIs. + + Returns: + A tuple containing the slides and tiles Datasets. + """ + # Parallelize MLFlow downloads (I/O Bound) + with ThreadPoolExecutor() as executor: + artifacts_paths = list( + executor.map(lambda uri: download_artifacts(artifact_uri=uri), uris) + ) + + search_dirs = [Path(p) for p in (*paths, *artifacts_paths)] + + # Handle empty datasets + if not len(search_dirs): + return HFDataset.from_dict({}), HFDataset.from_dict({}) + + def resolve_search_path(partition: str) -> list[dict[str, str]]: + return [ + {"data_dir": str(path / partition)} + if (path / partition).is_dir() + else {"data_files": str(path / f"{partition}.parquet")} + for path in search_dirs + ] + + try: + slides_ds = concatenate_datasets( + [ + load_dataset(**hf_kwargs, **datasource) + for datasource in resolve_search_path("slides") + ] + ) + + tiles_ds = concatenate_datasets( + [ + load_dataset(**hf_kwargs, **datasource) + for datasource in resolve_search_path("tiles") + ] + ) + + return slides_ds, tiles_ds + + except Exception as e: + msg = "Failed to load Parquet files." + raise RuntimeError(msg) from e