Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "rationai-mlkit"
version = "0.4.0"
version = "0.4.1"
description = ""
authors = [
{ name = "Matěj Pekár", email = "matejpekar@mail.muni.cz" },
Expand Down
134 changes: 84 additions & 50 deletions rationai/mlkit/data/datasets/meta_tiled_slides.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import TypeVar, cast
from typing import Any, TypeVar

import pyarrow.compute as pc
import numpy as np
import pyarrow as pa
from datasets import Dataset as HFDataset
from datasets import concatenate_datasets, load_dataset
from mlflow.artifacts import download_artifacts
Expand All @@ -31,71 +32,88 @@ def __init__(
paths: Iterable[Path | str] | None = None,
uris: Iterable[str] | None = None,
slides_and_tiles: tuple[HFDataset, HFDataset] | None = None,
hf_kwargs: dict[str, Any] | None = None,
) -> None:
"""Load slides and tiles from MLFlow artifacts.

Args:
paths: List of directories to load slides and tiles from. Each
directory must include two files: `slides.parquet` and `tiles.parquet`.
directory must include either single files (`slides.parquet`
and `tiles.parquet`) or subdirectories (`slides/` and `tiles/`)
containing chunked Parquet files.
uris: List of MLFlow artifact URIs pointing to folders containing
`slides.parquet` and `tiles.parquet`.
either single files (`slides.parquet` and `tiles.parquet`) or
subdirectories (`slides/` and `tiles/`) containing chunked
Parquet files.
slides_and_tiles: Tuple containing the slides and tiles Datasets.
hf_kwargs: Additional keyword arguments to pass to HuggingFace's
`load_dataset` function. Defaults to `{"path": "parquet", "split": "train"}`.
"""
assert paths or uris or slides_and_tiles, (
"At least one of paths, uris or slides_and_tiles must be provided."
)

slides, tiles = self.load_slides_and_tiles(paths or [], uris or [])
if hf_kwargs is None:
hf_kwargs = {"path": "parquet", "split": "train"}

slides, tiles = self.load_slides_and_tiles(paths or [], uris or [], hf_kwargs)

if slides_and_tiles is not None:
slides = concatenate_datasets([slides, slides_and_tiles[0]])
tiles = concatenate_datasets([tiles, slides_and_tiles[1]])

self.slides = slides
self.tiles = tiles.sort("slide_id")
self.tiles = tiles
self._slide_id_to_indices = self._build_tile_index(self.tiles)
Comment thread
JakubPekar marked this conversation as resolved.

super().__init__(self.generate_datasets())

@staticmethod
def _build_tile_index(tiles: HFDataset) -> dict[str, range]:
def _build_tile_index(tiles: HFDataset) -> dict[str | bytes, pa.ListScalar]:
"""Creates a fast lookup table for slide indices.

This function builds a mapping from `slide_id` to the range of indices in the
`tiles` dataset that correspond to that slide. It assumes that the `tiles` dataset
is sorted by `slide_id`, which allows for efficient retrieval of tile indices
for each slide without needing to scan the entire dataset for each slide.
This function builds a mapping from `slide_id` to the list of indices in the
`tiles` dataset that correspond to that slide.

Args:
tiles: A dataset containing a `slide_id` column, sorted by `slide_id`.
tiles: A dataset containing a `slide_id` column.

Returns:
A dictionary mapping each `slide_id` to a range of indices in the `tiles` dataset.
A dictionary mapping each `slide_id` to a list of indices in the `tiles` dataset.
"""
if len(tiles) == 0:
return {}

# Get the underlying Arrow table (zero-copy)
table = tiles.data.table
# Since it's sorted, we only care about where 'slide_id' changes.
slide_ids = table.column("slide_id")
# 1. Grab the column directly from the underlying PyArrow Table
slide_ids = tiles.data.column("slide_id")
num_rows = len(slide_ids)
Comment thread
JakubPekar marked this conversation as resolved.

# 2. Handle the "Large" type conversion
current_type = slide_ids.type
if pa.types.is_string(current_type):
slide_ids = slide_ids.cast(pa.large_string())
elif pa.types.is_binary(current_type):
slide_ids = slide_ids.cast(pa.large_binary())
Comment thread
JakubPekar marked this conversation as resolved.

# Since the dataset is sorted by 'slide_id', we can use
# run-end encoding to find group boundaries efficiently.
run_ends = pc.run_end_encode(slide_ids.combine_chunks()) # pyright: ignore[reportAttributeAccessIssue]
# 3. Generate sequential row indices
# np.arange is used here because PyArrow can wrap the resulting int64 buffer without copying it
row_indices = pa.array(np.arange(num_rows, dtype=np.int64))

values = run_ends.values
ends = run_ends.run_ends
# 4. Combine them into a lightweight PyArrow Table
table = pa.Table.from_arrays(
[slide_ids, row_indices], names=["slide_id", "idx"]
)

index_map = {}
current_offset = 0
# 5. Perform the native Arrow groupby and aggregate
# The "list" function aggregates all indices for a given slide_id into a single Arrow List scalar
grouped = table.group_by("slide_id").aggregate([("idx", "list")])

for sid, end in zip(values, ends, strict=True):
end_py = end.as_py()
index_map[sid.as_py()] = range(current_offset, end_py)
current_offset = end_py
# 6. Extract keys to Python, but KEEP values as PyArrow ListScalars
keys = grouped.column("slide_id").to_numpy()
values_array = grouped.column("idx_list")

return index_map
# Map each slide_id key (str or bytes) to its PyArrow ListScalar of row indices
return {key: values_array[i] for i, key in enumerate(keys)}

@abstractmethod
def generate_datasets(self) -> Iterable[Dataset[T]]:
Expand All @@ -116,7 +134,7 @@ def generate_datasets(self) -> Iterable[Dataset[T]]:
```
"""

def filter_tiles_by_slide(self, slide_id: str) -> HFDataset:
def filter_tiles_by_slide(self, slide_id: str | bytes) -> HFDataset:
"""Returns a view of the dataset using a slice or indices.

This function creates a view of the `self.tiles` dataset that contains only
Expand All @@ -130,20 +148,28 @@ def filter_tiles_by_slide(self, slide_id: str) -> HFDataset:
Returns:
A view of the tiles dataset containing only the tiles for the specified slide.
"""
tile_range = self._slide_id_to_indices.get(slide_id, range(0))
return self.tiles.select(tile_range)
tile_indices = self._slide_id_to_indices.get(
slide_id, pa.scalar([], type=pa.list_(pa.int64()))
)
return self.tiles.select(tile_indices.values.to_numpy())

@staticmethod
def load_slides_and_tiles(
paths: Iterable[str | Path], uris: Iterable[str]
paths: Iterable[str | Path], uris: Iterable[str], hf_kwargs: dict[str, Any]
) -> tuple[HFDataset, HFDataset]:
"""Load slides and tiles parquets from local storage and MLFlow artifacts.

Args:
paths: List of directories to load slides and tiles from. Each
directory must include two files: `slides.parquet` and `tiles.parquet`.
directory must include either single files (`slides.parquet`
and `tiles.parquet`) or subdirectories (`slides/` and `tiles/`)
containing chunked Parquet files.
uris: List of MLFlow artifact URIs pointing to folders containing
`slides.parquet` and `tiles.parquet`.
either single files (`slides.parquet` and `tiles.parquet`) or
subdirectories (`slides/` and `tiles/`) containing chunked
Parquet files.
hf_kwargs: Additional keyword arguments to pass to HuggingFace's
`load_dataset` function.

Raises:
FileNotFoundError: If the data cannot be loaded from the specified URIs.
Expand All @@ -159,27 +185,35 @@ def load_slides_and_tiles(

search_dirs = [Path(p) for p in (*paths, *artifacts_paths)]

# Extract existing file paths
slide_files = [
str(s) for p in search_dirs if (s := p / "slides.parquet").exists()
]
tile_files = [
str(t) for p in search_dirs if (t := p / "tiles.parquet").exists()
]

# Handle empty datasets
if not (slide_files and tile_files):
if not len(search_dirs):
return HFDataset.from_dict({}), HFDataset.from_dict({})

def resolve_search_path(partition: str) -> list[dict[str, str]]:
return [
{"data_dir": str(path / partition)}
if (path / partition).is_dir()
else {"data_files": str(path / f"{partition}.parquet")}
for path in search_dirs
]
Comment thread
JakubPekar marked this conversation as resolved.
Comment thread
JakubPekar marked this conversation as resolved.

try:
# Load datasets with memory mapping (lazy)
loader_kwargs = {"path": "parquet", "split": "train"}
slides_ds = concatenate_datasets(
[
load_dataset(**hf_kwargs, **datasource)
for datasource in resolve_search_path("slides")
]
Comment thread
JakubPekar marked this conversation as resolved.
)

slides_ds = load_dataset(**loader_kwargs, data_files=slide_files) # pyright: ignore[reportArgumentType, reportCallIssue]
tiles_ds = load_dataset(**loader_kwargs, data_files=tile_files) # pyright: ignore[reportArgumentType, reportCallIssue]
tiles_ds = concatenate_datasets(
[
load_dataset(**hf_kwargs, **datasource)
for datasource in resolve_search_path("tiles")
]
)
Comment thread
JakubPekar marked this conversation as resolved.

return cast("HFDataset", slides_ds), cast("HFDataset", tiles_ds)
return slides_ds, tiles_ds

except Exception as e:
msg = f"Failed to load Parquet files. Found {len(slide_files)} slides and {len(tile_files)} tiles."
msg = "Failed to load Parquet files."
Comment thread
JakubPekar marked this conversation as resolved.
raise FileNotFoundError(msg) from e
Comment thread
JakubPekar marked this conversation as resolved.
Comment thread
JakubPekar marked this conversation as resolved.
Comment thread
JakubPekar marked this conversation as resolved.
Loading