Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion rationai/mlkit/data/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from rationai.mlkit.data.datasets.meta_tiled_slides import MetaTiledSlides
from rationai.mlkit.data.datasets.openslide_tiles_dataset import OpenSlideTilesDataset
from rationai.mlkit.data.datasets.slides_tiles_loader import SlidesTilesLoader


__all__ = ["MetaTiledSlides", "OpenSlideTilesDataset"]
__all__ = ["MetaTiledSlides", "OpenSlideTilesDataset", "SlidesTilesLoader"]
165 changes: 19 additions & 146 deletions rationai/mlkit/data/datasets/meta_tiled_slides.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, TypeVar

import numpy as np
import pyarrow as pa
from datasets import Dataset as HFDataset
from datasets import concatenate_datasets, load_dataset
from mlflow.artifacts import download_artifacts
from torch.utils.data import ConcatDataset, Dataset

from rationai.mlkit.data.datasets.slides_tiles_loader import SlidesTilesLoader


T = TypeVar("T", covariant=True)

Expand Down Expand Up @@ -49,71 +46,31 @@ def __init__(
hf_kwargs: Additional keyword arguments to pass to HuggingFace's
`load_dataset` function. Defaults to `{"path": "parquet", "split": "train"}`.
"""
assert paths or uris or slides_and_tiles, (
"At least one of paths, uris or slides_and_tiles must be provided."
self._meta = SlidesTilesLoader(
paths=paths,
uris=uris,
slides_and_tiles=slides_and_tiles,
hf_kwargs=hf_kwargs,
)

if hf_kwargs is None:
hf_kwargs = {"path": "parquet", "split": "train"}

slides, tiles = self.load_slides_and_tiles(paths or [], uris or [], hf_kwargs)

if slides_and_tiles is not None:
slides = concatenate_datasets([slides, slides_and_tiles[0]])
tiles = concatenate_datasets([tiles, slides_and_tiles[1]])

self.slides = slides
self.tiles = tiles
self._slide_id_to_indices = self._build_tile_index(self.tiles)

self.slides = self._meta.slides
self.tiles = self._meta.tiles
super().__init__(self.generate_datasets())
Comment thread
Adames4 marked this conversation as resolved.

@staticmethod
def _build_tile_index(tiles: HFDataset) -> dict[str | bytes, pa.ListScalar]:
"""Creates a fast lookup table for slide indices.
def filter_tiles_by_slide(self, slide_id: str | bytes) -> HFDataset:
"""Returns a view of the dataset using a slice or indices.

This function builds a mapping from `slide_id` to the list of indices in the
`tiles` dataset that correspond to that slide.
This function creates a view of the `self.tiles` dataset that contains only
the tiles belonging to the specified slide. It uses the precomputed
`_slide_id_to_indices` mapping to efficiently retrieve the relevant tiles
without copying data.

Args:
tiles: A dataset containing a `slide_id` column.
slide_id: The ID of the slide to filter tiles.

Returns:
A dictionary mapping each `slide_id` to a list of indices in the `tiles` dataset.
A view of the tiles dataset containing only the tiles for the specified slide.
"""
if len(tiles) == 0:
return {}

# 1. Grab the column directly from the underlying PyArrow Table
slide_ids = tiles.data.column("slide_id")
num_rows = len(slide_ids)

# 2. Handle the "Large" type conversion
current_type = slide_ids.type
if pa.types.is_string(current_type):
slide_ids = slide_ids.cast(pa.large_string())
elif pa.types.is_binary(current_type):
slide_ids = slide_ids.cast(pa.large_binary())

# 3. Generate sequential row indices
# np.arange is used here because PyArrow can wrap it instantly with zero-copy overhead
row_indices = pa.array(np.arange(num_rows, dtype=np.int64))

# 4. Combine them into a lightweight PyArrow Table
table = pa.Table.from_arrays(
[slide_ids, row_indices], names=["slide_id", "idx"]
)

# 5. Perform the native Arrow groupby and aggregate
# The "list" function aggregates all indices for a given slide_id into a single Arrow List scalar
grouped = table.group_by("slide_id").aggregate([("idx", "list")])

# 6. Extract keys to Python, but KEEP values as PyArrow ListScalars
keys = grouped.column("slide_id").to_numpy()
values_array = grouped.column("idx_list")

# Map the string key to the PyArrow ListScalar
return {key: values_array[i] for i, key in enumerate(keys)}
return self._meta.filter_tiles_by_slide(slide_id)

@abstractmethod
def generate_datasets(self) -> Iterable[Dataset[T]]:
Expand All @@ -127,93 +84,9 @@ def generate_datasets(self) -> Iterable[Dataset[T]]:
level=slide["level"],
tile_extent_x=slide["tile_extent_x"],
tile_extent_y=slide["tile_extent_y"],
tiles=self.filter_tiles_by_slide(slide.id),
tiles=self.filter_tiles_by_slide(slide["id"]),
)
for slide in self.slides
)
```
"""

def filter_tiles_by_slide(self, slide_id: str | bytes) -> HFDataset:
    """Select the rows of `self.tiles` that belong to one slide.

    The row positions are looked up in the precomputed
    `self._slide_id_to_indices` mapping and handed to `self.tiles.select`.
    A slide ID that is missing from the mapping is treated as having no
    tiles, so the selection is made with an empty index list.

    Args:
        slide_id: The ID of the slide whose tiles should be selected.

    Returns:
        The result of `self.tiles.select(...)` for the rows recorded for
        `slide_id`.
    """
    # Fallback for unknown slides: an empty int64 list scalar, so the
    # lookup result exposes `.values` whether or not the slide is known.
    no_tiles = pa.scalar([], type=pa.list_(pa.int64()))
    matched = self._slide_id_to_indices.get(slide_id, no_tiles)

    row_numbers = matched.values.to_numpy()
    return self.tiles.select(row_numbers)

@staticmethod
def load_slides_and_tiles(
    paths: Iterable[str | Path], uris: Iterable[str], hf_kwargs: dict[str, Any]
) -> tuple[HFDataset, HFDataset]:
    """Load the slides and tiles Parquet data from local dirs and MLFlow.

    Every URI is first downloaded with `download_artifacts`; the resulting
    local paths are appended after `paths`. For each directory and each
    partition (`slides`, `tiles`) the loader uses the `<partition>/`
    subdirectory when it exists and falls back to `<partition>.parquet`
    otherwise. The per-directory datasets are concatenated per partition.

    Args:
        paths: Directories to load slides and tiles from. Each directory
            must include either single files (`slides.parquet` and
            `tiles.parquet`) or subdirectories (`slides/` and `tiles/`)
            containing chunked Parquet files.
        uris: MLFlow artifact URIs pointing to folders laid out the same
            way as the entries of `paths`.
        hf_kwargs: Keyword arguments forwarded to HuggingFace's
            `load_dataset` on every call.

    Raises:
        FileNotFoundError: If resolving or loading any partition fails; the
            original exception is attached as the cause.

    Returns:
        A `(slides, tiles)` tuple of datasets. When neither `paths` nor
        `uris` yields a directory, both are built from an empty dict.
    """

    def fetch(uri: str) -> str:
        return download_artifacts(artifact_uri=uri)

    # The downloads are submitted to a thread pool; results keep URI order.
    with ThreadPoolExecutor() as pool:
        downloaded = list(pool.map(fetch, uris))

    roots = [Path(entry) for entry in (*paths, *downloaded)]

    # Nothing to read from: hand back two empty datasets.
    if not roots:
        return HFDataset.from_dict({}), HFDataset.from_dict({})

    def source_kwargs(root: Path, partition: str) -> dict[str, str]:
        chunk_dir = root / partition
        if chunk_dir.is_dir():
            return {"data_dir": str(chunk_dir)}
        return {"data_files": str(root / f"{partition}.parquet")}

    def load_partition(partition: str) -> HFDataset:
        # Resolve every source before loading any of them, so all the
        # filesystem checks for a partition precede its first load.
        sources = [source_kwargs(root, partition) for root in roots]
        return concatenate_datasets(
            [load_dataset(**hf_kwargs, **source) for source in sources]
        )

    try:
        slides_ds = load_partition("slides")
        tiles_ds = load_partition("tiles")
    except Exception as e:
        msg = "Failed to load Parquet files."
        raise FileNotFoundError(msg) from e

    return slides_ds, tiles_ds
Loading
Loading