Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 4 additions & 11 deletions rationai/mlkit/data/datasets/slides_tiles_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,35 +76,28 @@ def _build_tile_index(tiles: HFDataset) -> dict[str | bytes, pa.ListScalar]:
if len(tiles) == 0:
return {}

# 1. Grab the column directly from the underlying PyArrow Table
slide_ids = tiles.data.column("slide_id")
num_rows = len(slide_ids)

# 2. Handle the "Large" type conversion
# group_by requires the "large" variants for string/binary columns
current_type = slide_ids.type
if pa.types.is_string(current_type):
slide_ids = slide_ids.cast(pa.large_string())
elif pa.types.is_binary(current_type):
slide_ids = slide_ids.cast(pa.large_binary())

# 3. Generate sequential row indices
# np.arange is used here because PyArrow can wrap it instantly with zero-copy overhead
# np.arange is used here because PyArrow can wrap it with zero-copy overhead
row_indices = pa.array(np.arange(num_rows, dtype=np.int64))

# 4. Combine them into a lightweight PyArrow Table
table = pa.Table.from_arrays(
[slide_ids, row_indices], names=["slide_id", "idx"]
)

# 5. Perform the native Arrow groupby and aggregate
# The "list" function aggregates all indices for a given slide_id into a single Arrow List scalar
# "list" aggregates all indices for a given slide_id into a single Arrow List scalar
grouped = table.group_by("slide_id").aggregate([("idx", "list")])

# 6. Extract keys to Python, but KEEP values as PyArrow ListScalars
# Keep values as PyArrow ListScalars to avoid materializing them in Python
keys = grouped.column("slide_id").to_numpy()
values_array = grouped.column("idx_list")

# Map the string key to the PyArrow ListScalar
return {key: values_array[i] for i, key in enumerate(keys)}

def filter_tiles_by_slide(self, slide_id: str | bytes) -> HFDataset:
Expand Down
12 changes: 5 additions & 7 deletions rationai/mlkit/data/samplers/stratified_batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ def __iter__(self) -> Iterator[list[int]]:
if not len(indices[group_idx]):
indices.pop(group_idx)

batch_indices = list(batch_indices)
random.shuffle(batch_indices)
yield batch_indices
batch_list = list(batch_indices)
random.shuffle(batch_list)
yield batch_list

def __len__(self) -> int:
return self._indices_size(self.data_indices).sum() // self.batch_size
Expand All @@ -89,16 +89,14 @@ class PDMStratifiedBatchSampler(StratifiedBatchSampler):

This sampler is designed to create balanced batches from a DataFrame by
stratifying samples based on a specified column.


"""

def __init__(
self,
data: pd.DataFrame,
stratify_by: None,
stratify_by: str | list[str],
batch_size: int,
**kwargs: dict[str, Any],
**kwargs: Any,
) -> None:
Comment on lines 94 to 100

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In Python type hinting, `**kwargs: T` specifies that every keyword argument passed must be of type `T`. Specifying `**kwargs: dict[str, Any]` therefore means that each individual keyword argument must itself be a dictionary, which is likely not the intention here. It should be typed as `**kwargs: Any`.

Suggested change
def __init__(
self,
data: pd.DataFrame,
stratify_by: None,
stratify_by: str | list[str],
batch_size: int,
**kwargs: dict[str, Any],
) -> None:
def __init__(
self,
data: pd.DataFrame,
stratify_by: str | list[str],
batch_size: int,
**kwargs: Any,
) -> None:

"""Initializes the PDMStratifiedBatchSampler with DataFrame and batch size.

Expand Down
7 changes: 0 additions & 7 deletions rationai/mlkit/data/shard_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def shard_parquet(
AssertionError: If `rows_per_shard` or `row_group_size` are not strictly positive,
or if `rows_per_shard` is not perfectly divisible by `row_group_size`.
"""
# --- Input Validation ---
assert rows_per_shard > 0, "rows_per_shard must be greater than 0"
assert row_group_size > 0, "row_group_size must be greater than 0"

Expand All @@ -40,31 +39,25 @@ def shard_parquet(
"rows_per_shard must be divisible by row_group_size"
)

# --- Setup Output Directory ---
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

# --- Read and Shard Process ---
with pq.ParquetFile(input_file) as parquet_file:
_logger.info(f"Total rows in source: {parquet_file.metadata.num_rows}")

# Initialize tracking variables
shard_idx = 0
current_shard_rows = 0
writer = None

try:
# Iterate through the source file in memory-efficient chunks (batches)
for batch in parquet_file.iter_batches(batch_size=row_group_size):
if writer is None:
out_path = output_dir / f"shard_{shard_idx:05d}.parquet"
writer = pq.ParquetWriter(out_path, batch.schema)

# Write the current batch
writer.write_batch(batch)
current_shard_rows += batch.num_rows

# Check if the current shard has reached its maximum capacity
if current_shard_rows >= rows_per_shard:
writer.close()
writer = None
Comment on lines 61 to 63

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There is a critical bug here: when a shard reaches its capacity (`rows_per_shard`), the writer is closed and set to `None`, but `shard_idx` is never incremented and `current_shard_rows` is never reset to 0.

This causes two major issues:

  1. The next batch will overwrite the same file (`shard_00000.parquet`) because `shard_idx` remains 0.
  2. `current_shard_rows` will keep accumulating and stay above `rows_per_shard`, causing a new file to be opened and closed on every single subsequent iteration.

We must increment `shard_idx` and reset `current_shard_rows` to 0 when closing the writer.

Suggested change
if current_shard_rows >= rows_per_shard:
writer.close()
writer = None
if current_shard_rows >= rows_per_shard:
writer.close()
writer = None
shard_idx += 1
current_shard_rows = 0

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is not true; the code is already there.

Expand Down
3 changes: 1 addition & 2 deletions rationai/mlkit/lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,8 @@ def _log_checkpoint(self, key: str, path: str) -> None:
self.run_id, tmpdir, f"{MLFLOW_CHECKPOINT_PATH}/{key}"
)

# Ensures that MLFlow logged checkpoints are in sync with those saved by the trainer.
def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
"""Scan checkpoints and log them to MLFlow if not already logged."""
"""Scan checkpoints and log them to MLFlow, keeping them in sync with the trainer."""
checkpoints = self._scan_checkpoints(checkpoint_callback)

logged_checkpoints = {
Expand Down
22 changes: 15 additions & 7 deletions rationai/mlkit/metrics/aggregated_metric_collection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import defaultdict
from collections.abc import Sequence
from typing import Any
from collections.abc import Callable, Sequence
from typing import Any, cast

from torch import Tensor
from torchmetrics import Metric, MetricCollection
Expand Down Expand Up @@ -79,11 +79,19 @@ def __init__(
aggregator: Aggregator,
prefix: str | None = None,
) -> None:
super().__init__(metrics, prefix=prefix)

self.aggregators: dict[str, Aggregator] = defaultdict(aggregator.clone)

def update( # pylint: disable=arguments-differ
super().__init__(
cast(
"Metric | MetricCollection | Sequence[Metric | MetricCollection] | dict[str, Metric | MetricCollection]",
metrics,
),
prefix=prefix,
)

self.aggregators: dict[str, Aggregator] = defaultdict(
cast("Callable[[], Aggregator]", aggregator.clone)
)

def update( # type: ignore[override]
self, preds: Tensor, targets: Tensor, keys: list[str], **kwargs: Any
) -> None:
kwargs_t = ({k: v[i] for k, v in kwargs.items()} for i in range(len(preds)))
Expand Down
14 changes: 9 additions & 5 deletions rationai/mlkit/metrics/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def compute(self) -> tuple[Tensor, Tensor]:
class MeanAggregator(Aggregator):
"""Aggregator to compute the mean of predictions and targets."""

preds: Tensor
targets: Tensor
count: Tensor

def __init__(self) -> None:
super().__init__()
self.add_state("preds", default=torch.tensor(0.0), dist_reduce_fx="sum")
Expand All @@ -58,18 +62,18 @@ def compute(self) -> tuple[Tensor, Tensor]:


class HeatmapAggregator(Aggregator):
preds: list[Tensor]
targets: list[Tensor]
xs: list[Tensor]
ys: list[Tensor]

"""Abstract aggregator covering the prediction heatmap generation.

Arguments:
extent_tile (int): Size of the tile.
stride_tile (int): Tile stride.
"""

preds: list[Tensor]
targets: list[Tensor]
xs: list[Tensor]
ys: list[Tensor]

def __init__(
self,
extent_tile: int,
Expand Down
14 changes: 9 additions & 5 deletions rationai/mlkit/metrics/lazy_metric_dict.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from copy import deepcopy
from typing import Any
from typing import Any, cast

from deprecated import deprecated
from deprecated import deprecated # type: ignore[import-untyped]
from torch.nn import ModuleDict
from torchmetrics import Metric, MetricCollection

Expand All @@ -19,11 +19,15 @@ def update(self, *args: Any, key: str, **kwargs: Any) -> None: # type: ignore[o
if key not in self:
self.add_module(key, deepcopy(self.metric))

self[key].update(*args, **kwargs)
cast("Metric | MetricCollection", self[key]).update(*args, **kwargs)

def compute(self) -> dict[str, Any]:
return {k: v.compute() for k, v in self.items() if k != "metric"}
return {
k: cast("Metric | MetricCollection", v).compute()
for k, v in self.items()
if k != "metric"
}

def reset(self) -> None:
for metric in self.values():
metric.reset()
cast("Metric | MetricCollection", metric).reset()
8 changes: 4 additions & 4 deletions rationai/mlkit/metrics/nested_metric_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class NestedMetricCollection(MetricCollection):
>>> # Create the NestedMetricCollection, setting 'slide' as the unique identifier for grouping. The class names are provided for multi-class metrics.
>>> nested_metrics = NestedMetricCollection(
... metrics,
... key_name="slide"
... key_name="slide",
... class_names=["A", "B", "C"],
... )

Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(
self.class_names = class_names
self.sep = sep

def update( # pylint: disable=arguments-differ
def update( # type: ignore[override]
self, preds: Tensor, targets: Tensor, keys: list[str]
) -> None:
for pred, target, key in zip(preds, targets, keys, strict=True):
Expand All @@ -101,7 +101,7 @@ def update( # pylint: disable=arguments-differ
self[new_name].update(pred.unsqueeze(0), target.unsqueeze(0))

def compute(self) -> dict[str, Any]:
divided_metrics = defaultdict(dict)
divided_metrics: defaultdict[str, dict[str, Any]] = defaultdict(dict)
for name, value in super().compute().items():
key, subkey = name.split(self.sep, maxsplit=1)

Expand All @@ -112,7 +112,7 @@ def compute(self) -> dict[str, Any]:
# handle multi-class metrics without averaging
assert len(value.shape) == 1
if self.class_names is None:
self.class_names = list(range(len(value)))
self.class_names = [str(i) for i in range(len(value))]

if len(value) != len(self.class_names):
raise ValueError(
Expand Down
4 changes: 4 additions & 0 deletions rationai/mlkit/mlflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from rationai.mlkit.mlflow.parquet_dataset import ParquetDataset, from_parquet


__all__ = ["ParquetDataset", "from_parquet"]
Loading