diff --git a/rationai/mlkit/data/datasets/slides_tiles_loader.py b/rationai/mlkit/data/datasets/slides_tiles_loader.py index 7359965..e950e77 100644 --- a/rationai/mlkit/data/datasets/slides_tiles_loader.py +++ b/rationai/mlkit/data/datasets/slides_tiles_loader.py @@ -76,35 +76,28 @@ def _build_tile_index(tiles: HFDataset) -> dict[str | bytes, pa.ListScalar]: if len(tiles) == 0: return {} - # 1. Grab the column directly from the underlying PyArrow Table slide_ids = tiles.data.column("slide_id") num_rows = len(slide_ids) - # 2. Handle the "Large" type conversion + # group_by requires the "large" variants for string/binary columns current_type = slide_ids.type if pa.types.is_string(current_type): slide_ids = slide_ids.cast(pa.large_string()) elif pa.types.is_binary(current_type): slide_ids = slide_ids.cast(pa.large_binary()) - # 3. Generate sequential row indices - # np.arange is used here because PyArrow can wrap it instantly with zero-copy overhead + # np.arange is used here because PyArrow can wrap it with zero-copy overhead row_indices = pa.array(np.arange(num_rows, dtype=np.int64)) - - # 4. Combine them into a lightweight PyArrow Table table = pa.Table.from_arrays( [slide_ids, row_indices], names=["slide_id", "idx"] ) - # 5. Perform the native Arrow groupby and aggregate - # The "list" function aggregates all indices for a given slide_id into a single Arrow List scalar + # "list" aggregates all indices for a given slide_id into a single Arrow List scalar grouped = table.group_by("slide_id").aggregate([("idx", "list")]) - # 6. 
Extract keys to Python, but KEEP values as PyArrow ListScalars + # Keep values as PyArrow ListScalars to avoid materializing them in Python keys = grouped.column("slide_id").to_numpy() values_array = grouped.column("idx_list") - - # Map the string key to the PyArrow ListScalar return {key: values_array[i] for i, key in enumerate(keys)} def filter_tiles_by_slide(self, slide_id: str | bytes) -> HFDataset: diff --git a/rationai/mlkit/data/samplers/stratified_batch_sampler.py b/rationai/mlkit/data/samplers/stratified_batch_sampler.py index 9db18c0..c2603de 100644 --- a/rationai/mlkit/data/samplers/stratified_batch_sampler.py +++ b/rationai/mlkit/data/samplers/stratified_batch_sampler.py @@ -67,9 +67,9 @@ def __iter__(self) -> Iterator[list[int]]: if not len(indices[group_idx]): indices.pop(group_idx) - batch_indices = list(batch_indices) - random.shuffle(batch_indices) - yield batch_indices + batch_list = list(batch_indices) + random.shuffle(batch_list) + yield batch_list def __len__(self) -> int: return self._indices_size(self.data_indices).sum() // self.batch_size @@ -89,16 +89,14 @@ class PDMStratifiedBatchSampler(StratifiedBatchSampler): This sampler is designed to create balanced batches from a DataFrame by stratifying samples based on a specified column. - - """ def __init__( self, data: pd.DataFrame, - stratify_by: None, + stratify_by: str | list[str], batch_size: int, - **kwargs: dict[str, Any], + **kwargs: Any, ) -> None: """Initializes the PDMStratifiedBatchSampler with DataFrame and batch size. diff --git a/rationai/mlkit/data/shard_parquet.py b/rationai/mlkit/data/shard_parquet.py index 95d2654..a0a4271 100644 --- a/rationai/mlkit/data/shard_parquet.py +++ b/rationai/mlkit/data/shard_parquet.py @@ -31,7 +31,6 @@ def shard_parquet( AssertionError: If `rows_per_shard` or `row_group_size` are not strictly positive, or if `rows_per_shard` is not perfectly divisible by `row_group_size`. 
""" - # --- Input Validation --- assert rows_per_shard > 0, "rows_per_shard must be greater than 0" assert row_group_size > 0, "row_group_size must be greater than 0" @@ -40,31 +39,25 @@ def shard_parquet( "rows_per_shard must be divisible by row_group_size" ) - # --- Setup Output Directory --- output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) - # --- Read and Shard Process --- with pq.ParquetFile(input_file) as parquet_file: _logger.info(f"Total rows in source: {parquet_file.metadata.num_rows}") - # Initialize tracking variables shard_idx = 0 current_shard_rows = 0 writer = None try: - # Iterate through the source file in memory-efficient chunks (batches) for batch in parquet_file.iter_batches(batch_size=row_group_size): if writer is None: out_path = output_dir / f"shard_{shard_idx:05d}.parquet" writer = pq.ParquetWriter(out_path, batch.schema) - # Write the current batch writer.write_batch(batch) current_shard_rows += batch.num_rows - # Check if the current shard has reached its maximum capacity if current_shard_rows >= rows_per_shard: writer.close() writer = None diff --git a/rationai/mlkit/lightning/loggers/mlflow.py b/rationai/mlkit/lightning/loggers/mlflow.py index 0f56026..a65a888 100644 --- a/rationai/mlkit/lightning/loggers/mlflow.py +++ b/rationai/mlkit/lightning/loggers/mlflow.py @@ -150,9 +150,8 @@ def _log_checkpoint(self, key: str, path: str) -> None: self.run_id, tmpdir, f"{MLFLOW_CHECKPOINT_PATH}/{key}" ) - # Ensures that MLFlow logged checkpoints are in sync with those saved by the trainer. 
def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: - """Scan checkpoints and log them to MLFlow if not already logged.""" + """Scan checkpoints and log them to MLFlow, keeping them in sync with the trainer.""" checkpoints = self._scan_checkpoints(checkpoint_callback) logged_checkpoints = { diff --git a/rationai/mlkit/metrics/aggregated_metric_collection.py b/rationai/mlkit/metrics/aggregated_metric_collection.py index 8da1dce..3487928 100644 --- a/rationai/mlkit/metrics/aggregated_metric_collection.py +++ b/rationai/mlkit/metrics/aggregated_metric_collection.py @@ -1,6 +1,6 @@ from collections import defaultdict -from collections.abc import Sequence -from typing import Any +from collections.abc import Callable, Sequence +from typing import Any, cast from torch import Tensor from torchmetrics import Metric, MetricCollection @@ -79,11 +79,19 @@ def __init__( aggregator: Aggregator, prefix: str | None = None, ) -> None: - super().__init__(metrics, prefix=prefix) - - self.aggregators: dict[str, Aggregator] = defaultdict(aggregator.clone) - - def update( # pylint: disable=arguments-differ + super().__init__( + cast( + "Metric | MetricCollection | Sequence[Metric | MetricCollection] | dict[str, Metric | MetricCollection]", + metrics, + ), + prefix=prefix, + ) + + self.aggregators: dict[str, Aggregator] = defaultdict( + cast("Callable[[], Aggregator]", aggregator.clone) + ) + + def update( # type: ignore[override] self, preds: Tensor, targets: Tensor, keys: list[str], **kwargs: Any ) -> None: kwargs_t = ({k: v[i] for k, v in kwargs.items()} for i in range(len(preds))) diff --git a/rationai/mlkit/metrics/aggregators.py b/rationai/mlkit/metrics/aggregators.py index 47057ab..a3ffe29 100644 --- a/rationai/mlkit/metrics/aggregators.py +++ b/rationai/mlkit/metrics/aggregators.py @@ -42,6 +42,10 @@ def compute(self) -> tuple[Tensor, Tensor]: class MeanAggregator(Aggregator): """Aggregator to compute the mean of predictions and targets.""" + preds: 
Tensor + targets: Tensor + count: Tensor + def __init__(self) -> None: super().__init__() self.add_state("preds", default=torch.tensor(0.0), dist_reduce_fx="sum") @@ -58,11 +62,6 @@ def compute(self) -> tuple[Tensor, Tensor]: class HeatmapAggregator(Aggregator): - preds: list[Tensor] - targets: list[Tensor] - xs: list[Tensor] - ys: list[Tensor] - """Abstract aggregator covering the prediction heatmap generation. Arguments: @@ -70,6 +69,11 @@ class HeatmapAggregator(Aggregator): stride_tile (int): Tile stride. """ + preds: list[Tensor] + targets: list[Tensor] + xs: list[Tensor] + ys: list[Tensor] + def __init__( self, extent_tile: int, diff --git a/rationai/mlkit/metrics/lazy_metric_dict.py b/rationai/mlkit/metrics/lazy_metric_dict.py index a3781e5..d110788 100644 --- a/rationai/mlkit/metrics/lazy_metric_dict.py +++ b/rationai/mlkit/metrics/lazy_metric_dict.py @@ -1,7 +1,7 @@ from copy import deepcopy -from typing import Any +from typing import Any, cast -from deprecated import deprecated +from deprecated import deprecated # type: ignore[import-untyped] from torch.nn import ModuleDict from torchmetrics import Metric, MetricCollection @@ -19,11 +19,15 @@ def update(self, *args: Any, key: str, **kwargs: Any) -> None: # type: ignore[o if key not in self: self.add_module(key, deepcopy(self.metric)) - self[key].update(*args, **kwargs) + cast("Metric | MetricCollection", self[key]).update(*args, **kwargs) def compute(self) -> dict[str, Any]: - return {k: v.compute() for k, v in self.items() if k != "metric"} + return { + k: cast("Metric | MetricCollection", v).compute() + for k, v in self.items() + if k != "metric" + } def reset(self) -> None: for metric in self.values(): - metric.reset() + cast("Metric | MetricCollection", metric).reset() diff --git a/rationai/mlkit/metrics/nested_metric_collection.py b/rationai/mlkit/metrics/nested_metric_collection.py index 41e3656..1db822f 100644 --- a/rationai/mlkit/metrics/nested_metric_collection.py +++ 
b/rationai/mlkit/metrics/nested_metric_collection.py @@ -35,7 +35,7 @@ class NestedMetricCollection(MetricCollection): >>> # Create the NestedMetricCollection, setting 'slide' as the unique identifier for grouping. The class names are provided for multi-class metrics. >>> nested_metrics = NestedMetricCollection( ... metrics, - ... key_name="slide" + ... key_name="slide", ... class_names=["A", "B", "C"], ... ) @@ -84,7 +84,7 @@ def __init__( self.class_names = class_names self.sep = sep - def update( # pylint: disable=arguments-differ + def update( # type: ignore[override] self, preds: Tensor, targets: Tensor, keys: list[str] ) -> None: for pred, target, key in zip(preds, targets, keys, strict=True): @@ -101,7 +101,7 @@ def update( # pylint: disable=arguments-differ self[new_name].update(pred.unsqueeze(0), target.unsqueeze(0)) def compute(self) -> dict[str, Any]: - divided_metrics = defaultdict(dict) + divided_metrics: defaultdict[str, dict[str, Any]] = defaultdict(dict) for name, value in super().compute().items(): key, subkey = name.split(self.sep, maxsplit=1) @@ -112,7 +112,7 @@ def compute(self) -> dict[str, Any]: # handle multi-class metrics without averaging assert len(value.shape) == 1 if self.class_names is None: - self.class_names = list(range(len(value))) + self.class_names = [str(i) for i in range(len(value))] if len(value) != len(self.class_names): raise ValueError( diff --git a/rationai/mlkit/mlflow/__init__.py b/rationai/mlkit/mlflow/__init__.py new file mode 100644 index 0000000..46edf7e --- /dev/null +++ b/rationai/mlkit/mlflow/__init__.py @@ -0,0 +1,4 @@ +from rationai.mlkit.mlflow.parquet_dataset import ParquetDataset, from_parquet + + +__all__ = ["ParquetDataset", "from_parquet"] diff --git a/rationai/mlkit/mlflow/parquet_dataset.py b/rationai/mlkit/mlflow/parquet_dataset.py new file mode 100644 index 0000000..107969d --- /dev/null +++ b/rationai/mlkit/mlflow/parquet_dataset.py @@ -0,0 +1,166 @@ +import hashlib +import json +import logging +from 
functools import cached_property +from typing import Any + +import pyarrow as pa +import pyarrow.dataset as ds +from mlflow.data.dataset import Dataset +from mlflow.data.dataset_source import DatasetSource +from mlflow.types.schema import Schema +from mlflow.types.utils import _infer_schema + + +_logger = logging.getLogger(__name__) + + +class ParquetDataset(Dataset): + """Represents a lazy-loaded Parquet dataset with MLflow Tracking.""" + + def __init__( + self, + path: str, + source: DatasetSource, + target_col: str | None = None, + name: str | None = None, + digest: str | None = None, + ): + """Initializes the ParquetDataset. + + Args: + path: Local path or URI to the Parquet file or directory. + source: The source of the parquet dataset. + target_col: The name of the column representing the target variable. Optional. + name: The name of the dataset. If unspecified, a name is automatically generated. + digest: The digest (hash) of the dataset. If unspecified, a fast metadata-based + digest is automatically computed to avoid hashing massive files. 
+ """ + self._path = path + self._target_col = target_col + + # Lazily load the dataset metadata without reading the data into memory + self._ds = ds.dataset(self._path, format="parquet") + + super().__init__(source=source, name=name, digest=digest) + + def _compute_digest(self) -> str: + """Computes a fast digest for the dataset based on schema and file paths.""" + hasher = hashlib.md5(usedforsecurity=False) + + # Hash the schema structure + hasher.update(str(self._ds.schema).encode("utf-8")) + + # Hash the sorted file paths to detect added/removed chunks + for file in sorted(self._ds.files): + hasher.update(file.encode("utf-8")) + + return hasher.hexdigest() + + def to_dict(self) -> dict[str, str]: + """Create config dictionary for the dataset.""" + config = super().to_dict() + config.update( + { + "schema": json.dumps({"mlflow_colspec": self.schema.to_dict()}), + "profile": json.dumps(self.profile), + } + ) + return config + + @property + def source(self) -> DatasetSource: + """The source of the dataset.""" + return self._source + + @property + def dataset(self) -> ds.Dataset: + """The underlying pyarrow Dataset object.""" + return self._ds + + @property + def target_col(self) -> str | None: + """The name of the target column, if specified.""" + return self._target_col + + @cached_property + def profile(self) -> Any: + """A profile of the dataset metadata. + + Reads Parquet footers to instantly get row counts and structural metadata + without loading the actual data blocks into memory. + """ + # count_rows() sums footer row counts (falling back to a scan only when + # metadata is unavailable) using pyarrow's multi-threaded C++ implementation, + # which is much faster than a Python-level fragment loop for many shards. 
+ total_rows = self._ds.count_rows() + + return { + "num_files": len(self._ds.files), + "total_rows": total_rows, + "num_columns": len(self._ds.schema.names), + "backend_format": "parquet", + } + + @cached_property + def schema(self) -> Schema: + """MLflow Schema representing the dataset features.""" + try: + # Create an empty PyArrow Table from the schema and convert to Pandas. + empty_df = pa.Table.from_batches([], schema=self._ds.schema).to_pandas() + inferred_schema = _infer_schema(empty_df) + return inferred_schema + except Exception as e: + _logger.warning( + f"Failed to infer schema for Parquet dataset. Exception: {e}" + ) + return Schema([]) + + +def from_parquet( + path: str, + source: str | DatasetSource | None = None, + target_col: str | None = None, + name: str | None = None, + digest: str | None = None, +) -> ParquetDataset: + """Constructs a ParquetDataset object from a single Parquet file or directory. + + Args: + path: Path to the Parquet file or directory of Parquet chunks. + source: The source from which the dataset was derived. + target_col: Optional column name for the target. + name: The name of the dataset. + digest: The dataset digest (hash). If unspecified, a metadata digest is computed. + + Example: + + .. 
code-block:: python + + import mlflow + # Works for both a single file and a directory of chunks + dataset = from_parquet( + path="/path/to/massive_dataset.parquet", target_col="label" + ) + mlflow.log_input(dataset, context="training") + """ + from mlflow.data.code_dataset_source import CodeDatasetSource + from mlflow.data.dataset_source_registry import resolve_dataset_source + from mlflow.tracking.context import registry + + if source is not None: + if isinstance(source, DatasetSource): + resolved_source = source + else: + resolved_source = resolve_dataset_source(source) + else: + context_tags = registry.resolve_tags() + resolved_source = CodeDatasetSource(tags=context_tags) + + return ParquetDataset( + path=path, + source=resolved_source, + target_col=target_col, + name=name, + digest=digest, + ) diff --git a/rationai/mlkit/with_cli_args.py b/rationai/mlkit/with_cli_args.py index ebddc0e..ea309ac 100644 --- a/rationai/mlkit/with_cli_args.py +++ b/rationai/mlkit/with_cli_args.py @@ -35,21 +35,13 @@ def with_cli_args( def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: - # 1. Save original state original_argv = sys.argv[:] - - # 2. Deconstruct existing argv - # sys.argv[0] is the script name - script_name = [sys.argv[0]] - user_provided_args = sys.argv[1:] - - # 3. Reconstruct: [Script] + [Start] + [User] + [End] + script_name, user_provided_args = sys.argv[:1], sys.argv[1:] sys.argv = script_name + prepend + user_provided_args + append try: return func(*args, **kwargs) finally: - # 4. Restore original state guarantees safety sys.argv = original_argv return wrapper diff --git a/uv.lock b/uv.lock index 6c5d097..cb69bbf 100644 --- a/uv.lock +++ b/uv.lock @@ -2287,7 +2287,7 @@ dependencies = [ [[package]] name = "rationai-mlkit" -version = "0.4.0" +version = "0.4.1" source = { virtual = "." } dependencies = [ { name = "datasets" },