From 58ea6832345b11ea6c2804ff5a84723959fc6f03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Pek=C3=A1r?= <492788@mail.muni.cz> Date: Tue, 21 Apr 2026 12:53:26 +0000 Subject: [PATCH 1/7] feat: mlflow datset --- rationai/mlkit/mlflow/parquet_dataset.py | 154 +++++++++++++++++++++++ uv.lock | 2 +- 2 files changed, 155 insertions(+), 1 deletion(-) create mode 100644 rationai/mlkit/mlflow/parquet_dataset.py diff --git a/rationai/mlkit/mlflow/parquet_dataset.py b/rationai/mlkit/mlflow/parquet_dataset.py new file mode 100644 index 0000000..864ef12 --- /dev/null +++ b/rationai/mlkit/mlflow/parquet_dataset.py @@ -0,0 +1,154 @@ +import hashlib +import json +import logging +from functools import cached_property +from typing import Any + +import pyarrow.dataset as ds +from mlflow.data.dataset import Dataset +from mlflow.data.dataset_source import DatasetSource +from mlflow.types.schema import Schema +from mlflow.types.utils import _infer_schema + + +_logger = logging.getLogger(__name__) + + +class ParquetDataset(Dataset): + """Represents a lazy-loaded Parquet dataset with MLflow Tracking.""" + + def __init__( + self, + path: str, + source: DatasetSource, + target_col: str | None = None, + name: str | None = None, + digest: str | None = None, + ): + """Hety. + + Args: + path: Local path or URI to the Parquet file or directory. + source: The source of the parquet dataset. + target_col: The name of the column representing the target variable. Optional. + name: The name of the dataset. If unspecified, a name is automatically generated. + digest: The digest (hash) of the dataset. If unspecified, a fast metadata-based + digest is automatically computed to avoid hashing massive files. + """ + self._path = path + self._target_col = target_col + + # Lazily load the dataset metadata without reading the data into memory + self._ds = ds.dataset(self._path, format="parquet") + + super().__init__(source=source, name=name, digest=digest) + + def _compute_digest(self) -> str: + """Computes a fast digest for the dataset based on schema and file paths.""" + hasher = hashlib.md5() + + # Hash the schema structure + hasher.update(str(self._ds.schema).encode("utf-8")) + + # Hash the sorted file paths to detect added/removed chunks + for file in sorted(self._ds.files): + hasher.update(file.encode("utf-8")) + + return hasher.hexdigest() + + def to_dict(self) -> dict[str, str]: + """Create config dictionary for the dataset.""" + config = super().to_dict() + config.update({ + "schema": json.dumps(self.schema.to_dict()), + "profile": json.dumps(self.profile), + }) + return config + + @property + def source(self) -> DatasetSource: + """The source of the dataset.""" + return self._source + + @property + def dataset(self) -> ds.Dataset: + """The underlying pyarrow Dataset object.""" + return self._ds + + @property + def target_col(self) -> str | None: + """The name of the target column, if specified.""" + return self._target_col + + @property + def profile(self) -> Any: + """A lightweight profile of the dataset metadata.""" + return { + "num_files": len(self._ds.files), + "backend_format": "parquet", + } + + @cached_property + def schema(self) -> Schema: + """MLflow Schema representing the dataset features.""" + try: + # Fetch an empty Pandas dataframe from the schema to utilize MLflow's built-in + # inference securely and correctly map PyArrow types to MLflow types. + empty_df = self._ds.head(0).to_pandas() + inferred_schema = _infer_schema(empty_df) + return inferred_schema + except Exception as e: + _logger.warning(f"Failed to infer schema for Parquet dataset. Exception: {e}") + return Schema([]) + + +def from_parquet( + path: str, + source: str | DatasetSource | None = None, + target_col: str | None = None, + name: str | None = None, + digest: str | None = None, +) -> ParquetDataset: + """Constructs a ParquetDataset object from a single Parquet file or directory. + + Args: + path: Path to the Parquet file or directory of Parquet chunks. + source: The source from which the dataset was derived. + target_col: Optional column name for the target. + name: The name of the dataset. + digest: The dataset digest (hash). If unspecified, a metadata digest is computed. + + Example: + + .. code-block:: python + import mlflow + + # Works for both a single file and a directory of chunks + dataset = from_parquet( + path="/path/to/massive_dataset.parquet", + target_col="label" + ) + mlflow.log_input(dataset, context="training") + """ + from mlflow.data.code_dataset_source import CodeDatasetSource + from mlflow.data.dataset_source_registry import resolve_dataset_source + from mlflow.tracking.context import registry + + if source is not None: + if isinstance(source, DatasetSource): + resolved_source = source + else: + resolved_source = resolve_dataset_source(source) + else: + context_tags = registry.resolve_tags() + resolved_source = CodeDatasetSource(tags=context_tags) + + return ParquetDataset( + path=path, + source=resolved_source, + target_col=target_col, + name=name, + digest=digest + ) + + diff --git a/uv.lock b/uv.lock index 6c5d097..cb69bbf 100644 --- a/uv.lock +++ b/uv.lock @@ -2287,7 +2287,7 @@ dependencies = [ [[package]] name = "rationai-mlkit" -version = "0.4.0" +version = "0.4.1" source = { virtual = "." } dependencies = [ { name = "datasets" }, From 6cb54724f42e1e917fe8c00465f59bae2cf93dca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Pek=C3=A1r?= <492788@mail.muni.cz> Date: Tue, 21 Apr 2026 13:30:53 +0000 Subject: [PATCH 2/7] feat: load parquet --- rationai/mlkit/mlflow/__init__.py | 4 ++++ rationai/mlkit/mlflow/parquet_dataset.py | 26 +++++++++++++++++++----- 2 files changed, 25 insertions(+), 5 deletions(-) create mode 100644 rationai/mlkit/mlflow/__init__.py diff --git a/rationai/mlkit/mlflow/__init__.py b/rationai/mlkit/mlflow/__init__.py new file mode 100644 index 0000000..46edf7e --- /dev/null +++ b/rationai/mlkit/mlflow/__init__.py @@ -0,0 +1,4 @@ +from rationai.mlkit.mlflow.parquet_dataset import ParquetDataset, from_parquet + + +__all__ = ["ParquetDataset", "from_parquet"] diff --git a/rationai/mlkit/mlflow/parquet_dataset.py b/rationai/mlkit/mlflow/parquet_dataset.py index 864ef12..7e3c48b 100644 --- a/rationai/mlkit/mlflow/parquet_dataset.py +++ b/rationai/mlkit/mlflow/parquet_dataset.py @@ -60,7 +60,7 @@ def to_dict(self) -> dict[str, str]: """Create config dictionary for the dataset.""" config = super().to_dict() config.update({ - "schema": json.dumps(self.schema.to_dict()), + "schema": json.dumps({"mlflow_colspec": self.schema.to_dict()}), "profile": json.dumps(self.profile), }) return config @@ -82,9 +82,27 @@ def target_col(self) -> str | None: @property def profile(self) -> Any: - """A lightweight profile of the dataset metadata.""" + """A profile of the dataset metadata. + + Reads Parquet footers to instantly get row counts and structural metadata + without loading the actual data blocks into memory. + """ + total_rows = 0 + + # Iterate over file fragments to read metadata directly from Parquet footers + for fragment in self._ds.get_fragments(): + # ParquetFileFragment natively exposes 'metadata' + if hasattr(fragment, "metadata") and fragment.metadata is not None: + chunk_rows = fragment.metadata.num_rows + else: + chunk_rows = fragment.count_rows() + + total_rows += chunk_rows + return { "num_files": len(self._ds.files), + "total_rows": total_rows, + "num_columns": len(self._ds.schema.names), "backend_format": "parquet", } @@ -149,6 +167,4 @@ def from_parquet( target_col=target_col, name=name, digest=digest - ) - - + ) \ No newline at end of file From 8411863013bd56e537627251abd98b14cf433f99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Pek=C3=A1r?= <492788@mail.muni.cz> Date: Tue, 21 Apr 2026 13:33:07 +0000 Subject: [PATCH 3/7] feat: lint --- rationai/mlkit/mlflow/parquet_dataset.py | 59 +++++++++++++----------- 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/rationai/mlkit/mlflow/parquet_dataset.py b/rationai/mlkit/mlflow/parquet_dataset.py index 7e3c48b..8c160bc 100644 --- a/rationai/mlkit/mlflow/parquet_dataset.py +++ b/rationai/mlkit/mlflow/parquet_dataset.py @@ -32,37 +32,39 @@ def __init__( source: The source of the parquet dataset. target_col: The name of the column representing the target variable. Optional. name: The name of the dataset. If unspecified, a name is automatically generated. - digest: The digest (hash) of the dataset. If unspecified, a fast metadata-based + digest: The digest (hash) of the dataset. If unspecified, a fast metadata-based digest is automatically computed to avoid hashing massive files. """ self._path = path self._target_col = target_col - + # Lazily load the dataset metadata without reading the data into memory self._ds = ds.dataset(self._path, format="parquet") - + super().__init__(source=source, name=name, digest=digest) def _compute_digest(self) -> str: """Computes a fast digest for the dataset based on schema and file paths.""" hasher = hashlib.md5() - + # Hash the schema structure hasher.update(str(self._ds.schema).encode("utf-8")) - + # Hash the sorted file paths to detect added/removed chunks for file in sorted(self._ds.files): hasher.update(file.encode("utf-8")) - + return hasher.hexdigest() def to_dict(self) -> dict[str, str]: """Create config dictionary for the dataset.""" config = super().to_dict() - config.update({ - "schema": json.dumps({"mlflow_colspec": self.schema.to_dict()}), - "profile": json.dumps(self.profile), - }) + config.update( + { + "schema": json.dumps({"mlflow_colspec": self.schema.to_dict()}), + "profile": json.dumps(self.profile), + } + ) return config @property @@ -84,11 +86,11 @@ def target_col(self) -> str | None: def profile(self) -> Any: """A profile of the dataset metadata. - Reads Parquet footers to instantly get row counts and structural metadata + Reads Parquet footers to instantly get row counts and structural metadata without loading the actual data blocks into memory. """ total_rows = 0 - + # Iterate over file fragments to read metadata directly from Parquet footers for fragment in self._ds.get_fragments(): # ParquetFileFragment natively exposes 'metadata' @@ -96,7 +98,7 @@ def profile(self) -> Any: chunk_rows = fragment.metadata.num_rows else: chunk_rows = fragment.count_rows() - + total_rows += chunk_rows return { @@ -110,13 +112,15 @@ def profile(self) -> Any: def schema(self) -> Schema: """MLflow Schema representing the dataset features.""" try: - # Fetch an empty Pandas dataframe from the schema to utilize MLflow's built-in + # Fetch an empty Pandas dataframe from the schema to utilize MLflow's built-in # inference securely and correctly map PyArrow types to MLflow types. empty_df = self._ds.head(0).to_pandas() inferred_schema = _infer_schema(empty_df) return inferred_schema except Exception as e: - _logger.warning(f"Failed to infer schema for Parquet dataset. Exception: {e}") + _logger.warning( + f"Failed to infer schema for Parquet dataset. Exception: {e}" + ) return Schema([]) @@ -128,23 +132,22 @@ def from_parquet( digest: str | None = None, ) -> ParquetDataset: """Constructs a ParquetDataset object from a single Parquet file or directory. - + Args: path: Path to the Parquet file or directory of Parquet chunks. - source: The source from which the dataset was derived. + source: The source from which the dataset was derived. target_col: Optional column name for the target. name: The name of the dataset. digest: The dataset digest (hash). If unspecified, a metadata digest is computed. Example: - + .. code-block:: python import mlflow - + # Works for both a single file and a directory of chunks dataset = from_parquet( - path="/path/to/massive_dataset.parquet", - target_col="label" + path="/path/to/massive_dataset.parquet", target_col="label" ) mlflow.log_input(dataset, context="training") """ @@ -160,11 +163,11 @@ def from_parquet( else: context_tags = registry.resolve_tags() resolved_source = CodeDatasetSource(tags=context_tags) - + return ParquetDataset( - path=path, - source=resolved_source, - target_col=target_col, - name=name, - digest=digest - ) \ No newline at end of file + path=path, + source=resolved_source, + target_col=target_col, + name=name, + digest=digest, + ) From 573d7b479f90fafd7b49bdbe25a49ab78010151a Mon Sep 17 00:00:00 2001 From: JakubPekar Date: Wed, 24 Jun 2026 10:43:21 +0200 Subject: [PATCH 4/7] feat: fix mypy --- .../data/datasets/slides_tiles_loader.py | 15 ++++--------- .../data/samplers/stratified_batch_sampler.py | 10 ++++----- rationai/mlkit/data/shard_parquet.py | 7 ------ rationai/mlkit/lightning/loggers/mlflow.py | 3 +-- .../metrics/aggregated_metric_collection.py | 22 +++++++++++++------ rationai/mlkit/metrics/aggregators.py | 14 +++++++----- rationai/mlkit/metrics/lazy_metric_dict.py | 14 +++++++----- .../mlkit/metrics/nested_metric_collection.py | 8 +++---- rationai/mlkit/mlflow/parquet_dataset.py | 17 +++++--------- rationai/mlkit/with_cli_args.py | 10 +-------- 10 files changed, 52 insertions(+), 68 deletions(-) diff --git a/rationai/mlkit/data/datasets/slides_tiles_loader.py b/rationai/mlkit/data/datasets/slides_tiles_loader.py index 7359965..e950e77 100644 --- a/rationai/mlkit/data/datasets/slides_tiles_loader.py +++ b/rationai/mlkit/data/datasets/slides_tiles_loader.py @@ -76,35 +76,28 @@ def _build_tile_index(tiles: HFDataset) -> dict[str | bytes, pa.ListScalar]: if len(tiles) == 0: return {} - # 1. Grab the column directly from the underlying PyArrow Table slide_ids = tiles.data.column("slide_id") num_rows = len(slide_ids) - # 2. Handle the "Large" type conversion + # group_by requires the "large" variants for string/binary columns current_type = slide_ids.type if pa.types.is_string(current_type): slide_ids = slide_ids.cast(pa.large_string()) elif pa.types.is_binary(current_type): slide_ids = slide_ids.cast(pa.large_binary()) - # 3. Generate sequential row indices - # np.arange is used here because PyArrow can wrap it instantly with zero-copy overhead + # np.arange is used here because PyArrow can wrap it with zero-copy overhead row_indices = pa.array(np.arange(num_rows, dtype=np.int64)) - - # 4. Combine them into a lightweight PyArrow Table table = pa.Table.from_arrays( [slide_ids, row_indices], names=["slide_id", "idx"] ) - # 5. Perform the native Arrow groupby and aggregate - # The "list" function aggregates all indices for a given slide_id into a single Arrow List scalar + # "list" aggregates all indices for a given slide_id into a single Arrow List scalar grouped = table.group_by("slide_id").aggregate([("idx", "list")]) - # 6. Extract keys to Python, but KEEP values as PyArrow ListScalars + # Keep values as PyArrow ListScalars to avoid materializing them in Python keys = grouped.column("slide_id").to_numpy() values_array = grouped.column("idx_list") - - # Map the string key to the PyArrow ListScalar return {key: values_array[i] for i, key in enumerate(keys)} def filter_tiles_by_slide(self, slide_id: str | bytes) -> HFDataset: diff --git a/rationai/mlkit/data/samplers/stratified_batch_sampler.py b/rationai/mlkit/data/samplers/stratified_batch_sampler.py index 9db18c0..dcf332f 100644 --- a/rationai/mlkit/data/samplers/stratified_batch_sampler.py +++ b/rationai/mlkit/data/samplers/stratified_batch_sampler.py @@ -67,9 +67,9 @@ def __iter__(self) -> Iterator[list[int]]: if not len(indices[group_idx]): indices.pop(group_idx) - batch_indices = list(batch_indices) - random.shuffle(batch_indices) - yield batch_indices + batch_list = list(batch_indices) + random.shuffle(batch_list) + yield batch_list def __len__(self) -> int: return self._indices_size(self.data_indices).sum() // self.batch_size @@ -89,14 +89,12 @@ class PDMStratifiedBatchSampler(StratifiedBatchSampler): This sampler is designed to create balanced batches from a DataFrame by stratifying samples based on a specified column. - - """ def __init__( self, data: pd.DataFrame, - stratify_by: None, + stratify_by: str | list[str], batch_size: int, **kwargs: dict[str, Any], ) -> None: diff --git a/rationai/mlkit/data/shard_parquet.py b/rationai/mlkit/data/shard_parquet.py index 95d2654..a0a4271 100644 --- a/rationai/mlkit/data/shard_parquet.py +++ b/rationai/mlkit/data/shard_parquet.py @@ -31,7 +31,6 @@ def shard_parquet( AssertionError: If `rows_per_shard` or `row_group_size` are not strictly positive, or if `rows_per_shard` is not perfectly divisible by `row_group_size`. """ - # --- Input Validation --- assert rows_per_shard > 0, "rows_per_shard must be greater than 0" assert row_group_size > 0, "row_group_size must be greater than 0" @@ -40,31 +39,25 @@ def shard_parquet( "rows_per_shard must be divisible by row_group_size" ) - # --- Setup Output Directory --- output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) - # --- Read and Shard Process --- with pq.ParquetFile(input_file) as parquet_file: _logger.info(f"Total rows in source: {parquet_file.metadata.num_rows}") - # Initialize tracking variables shard_idx = 0 current_shard_rows = 0 writer = None try: - # Iterate through the source file in memory-efficient chunks (batches) for batch in parquet_file.iter_batches(batch_size=row_group_size): if writer is None: out_path = output_dir / f"shard_{shard_idx:05d}.parquet" writer = pq.ParquetWriter(out_path, batch.schema) - # Write the current batch writer.write_batch(batch) current_shard_rows += batch.num_rows - # Check if the current shard has reached its maximum capacity if current_shard_rows >= rows_per_shard: writer.close() writer = None diff --git a/rationai/mlkit/lightning/loggers/mlflow.py b/rationai/mlkit/lightning/loggers/mlflow.py index 0f56026..a65a888 100644 --- a/rationai/mlkit/lightning/loggers/mlflow.py +++ b/rationai/mlkit/lightning/loggers/mlflow.py @@ -150,9 +150,8 @@ def _log_checkpoint(self, key: str, path: str) -> None: self.run_id, tmpdir, f"{MLFLOW_CHECKPOINT_PATH}/{key}" ) - # Ensures that MLFlow logged checkpoints are in sync with those saved by the trainer. def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: - """Scan checkpoints and log them to MLFlow if not already logged.""" + """Scan checkpoints and log them to MLFlow, keeping them in sync with the trainer.""" checkpoints = self._scan_checkpoints(checkpoint_callback) logged_checkpoints = { diff --git a/rationai/mlkit/metrics/aggregated_metric_collection.py b/rationai/mlkit/metrics/aggregated_metric_collection.py index 8da1dce..3487928 100644 --- a/rationai/mlkit/metrics/aggregated_metric_collection.py +++ b/rationai/mlkit/metrics/aggregated_metric_collection.py @@ -1,6 +1,6 @@ from collections import defaultdict -from collections.abc import Sequence -from typing import Any +from collections.abc import Callable, Sequence +from typing import Any, cast from torch import Tensor from torchmetrics import Metric, MetricCollection @@ -79,11 +79,19 @@ def __init__( aggregator: Aggregator, prefix: str | None = None, ) -> None: - super().__init__(metrics, prefix=prefix) - - self.aggregators: dict[str, Aggregator] = defaultdict(aggregator.clone) - - def update( # pylint: disable=arguments-differ + super().__init__( + cast( + "Metric | MetricCollection | Sequence[Metric | MetricCollection] | dict[str, Metric | MetricCollection]", + metrics, + ), + prefix=prefix, + ) + + self.aggregators: dict[str, Aggregator] = defaultdict( + cast("Callable[[], Aggregator]", aggregator.clone) + ) + + def update( # type: ignore[override] self, preds: Tensor, targets: Tensor, keys: list[str], **kwargs: Any ) -> None: kwargs_t = ({k: v[i] for k, v in kwargs.items()} for i in range(len(preds))) diff --git a/rationai/mlkit/metrics/aggregators.py b/rationai/mlkit/metrics/aggregators.py index 47057ab..a3ffe29 100644 --- a/rationai/mlkit/metrics/aggregators.py +++ b/rationai/mlkit/metrics/aggregators.py @@ -42,6 +42,10 @@ def compute(self) -> tuple[Tensor, Tensor]: class MeanAggregator(Aggregator): """Aggregator to compute the mean of predictions and targets.""" + preds: Tensor + targets: Tensor + count: Tensor + def __init__(self) -> None: super().__init__() self.add_state("preds", default=torch.tensor(0.0), dist_reduce_fx="sum") @@ -58,11 +62,6 @@ def compute(self) -> tuple[Tensor, Tensor]: class HeatmapAggregator(Aggregator): - preds: list[Tensor] - targets: list[Tensor] - xs: list[Tensor] - ys: list[Tensor] - """Abstract aggregator covering the prediction heatmap generation. Arguments: @@ -70,6 +69,11 @@ class HeatmapAggregator(Aggregator): stride_tile (int): Tile stride. """ + preds: list[Tensor] + targets: list[Tensor] + xs: list[Tensor] + ys: list[Tensor] + def __init__( self, extent_tile: int, diff --git a/rationai/mlkit/metrics/lazy_metric_dict.py b/rationai/mlkit/metrics/lazy_metric_dict.py index a3781e5..d110788 100644 --- a/rationai/mlkit/metrics/lazy_metric_dict.py +++ b/rationai/mlkit/metrics/lazy_metric_dict.py @@ -1,7 +1,7 @@ from copy import deepcopy -from typing import Any +from typing import Any, cast -from deprecated import deprecated +from deprecated import deprecated # type: ignore[import-untyped] from torch.nn import ModuleDict from torchmetrics import Metric, MetricCollection @@ -19,11 +19,15 @@ def update(self, *args: Any, key: str, **kwargs: Any) -> None: # type: ignore[o if key not in self: self.add_module(key, deepcopy(self.metric)) - self[key].update(*args, **kwargs) + cast("Metric | MetricCollection", self[key]).update(*args, **kwargs) def compute(self) -> dict[str, Any]: - return {k: v.compute() for k, v in self.items() if k != "metric"} + return { + k: cast("Metric | MetricCollection", v).compute() + for k, v in self.items() + if k != "metric" + } def reset(self) -> None: for metric in self.values(): - metric.reset() + cast("Metric | MetricCollection", metric).reset() diff --git a/rationai/mlkit/metrics/nested_metric_collection.py b/rationai/mlkit/metrics/nested_metric_collection.py index 41e3656..1db822f 100644 --- a/rationai/mlkit/metrics/nested_metric_collection.py +++ b/rationai/mlkit/metrics/nested_metric_collection.py @@ -35,7 +35,7 @@ class NestedMetricCollection(MetricCollection): >>> # Create the NestedMetricCollection, setting 'slide' as the unique identifier for grouping. The class names are provided for multi-class metrics. >>> nested_metrics = NestedMetricCollection( ... metrics, - ... key_name="slide" + ... key_name="slide", ... class_names=["A", "B", "C"], ... ) @@ -84,7 +84,7 @@ def __init__( self.class_names = class_names self.sep = sep - def update( # pylint: disable=arguments-differ + def update( # type: ignore[override] self, preds: Tensor, targets: Tensor, keys: list[str] ) -> None: for pred, target, key in zip(preds, targets, keys, strict=True): @@ -101,7 +101,7 @@ def update( # pylint: disable=arguments-differ self[new_name].update(pred.unsqueeze(0), target.unsqueeze(0)) def compute(self) -> dict[str, Any]: - divided_metrics = defaultdict(dict) + divided_metrics: defaultdict[str, dict[str, Any]] = defaultdict(dict) for name, value in super().compute().items(): key, subkey = name.split(self.sep, maxsplit=1) @@ -112,7 +112,7 @@ def compute(self) -> dict[str, Any]: # handle multi-class metrics without averaging assert len(value.shape) == 1 if self.class_names is None: - self.class_names = list(range(len(value))) + self.class_names = [str(i) for i in range(len(value))] if len(value) != len(self.class_names): raise ValueError( diff --git a/rationai/mlkit/mlflow/parquet_dataset.py b/rationai/mlkit/mlflow/parquet_dataset.py index 8c160bc..b8717a9 100644 --- a/rationai/mlkit/mlflow/parquet_dataset.py +++ b/rationai/mlkit/mlflow/parquet_dataset.py @@ -82,24 +82,17 @@ def target_col(self) -> str | None: """The name of the target column, if specified.""" return self._target_col - @property + @cached_property def profile(self) -> Any: """A profile of the dataset metadata. Reads Parquet footers to instantly get row counts and structural metadata without loading the actual data blocks into memory. """ - total_rows = 0 - - # Iterate over file fragments to read metadata directly from Parquet footers - for fragment in self._ds.get_fragments(): - # ParquetFileFragment natively exposes 'metadata' - if hasattr(fragment, "metadata") and fragment.metadata is not None: - chunk_rows = fragment.metadata.num_rows - else: - chunk_rows = fragment.count_rows() - - total_rows += chunk_rows + # count_rows() sums footer row counts (falling back to a scan only when + # metadata is unavailable) using pyarrow's multi-threaded C++ implementation, + # which is much faster than a Python-level fragment loop for many shards. + total_rows = self._ds.count_rows() return { "num_files": len(self._ds.files), diff --git a/rationai/mlkit/with_cli_args.py b/rationai/mlkit/with_cli_args.py index ebddc0e..ea309ac 100644 --- a/rationai/mlkit/with_cli_args.py +++ b/rationai/mlkit/with_cli_args.py @@ -35,21 +35,13 @@ def with_cli_args( def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: - # 1. Save original state original_argv = sys.argv[:] - - # 2. Deconstruct existing argv - # sys.argv[0] is the script name - script_name = [sys.argv[0]] - user_provided_args = sys.argv[1:] - - # 3. Reconstruct: [Script] + [Start] + [User] + [End] + script_name, user_provided_args = sys.argv[:1], sys.argv[1:] sys.argv = script_name + prepend + user_provided_args + append try: return func(*args, **kwargs) finally: - # 4. Restore original state guarantees safety sys.argv = original_argv return wrapper From c5da4b8bacf9ff6635f56ef7aab30e522dab21a6 Mon Sep 17 00:00:00 2001 From: JakubPekar Date: Wed, 24 Jun 2026 10:49:04 +0200 Subject: [PATCH 5/7] feat: docs --- rationai/mlkit/mlflow/parquet_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rationai/mlkit/mlflow/parquet_dataset.py b/rationai/mlkit/mlflow/parquet_dataset.py index b8717a9..c9da4e7 100644 --- a/rationai/mlkit/mlflow/parquet_dataset.py +++ b/rationai/mlkit/mlflow/parquet_dataset.py @@ -25,7 +25,7 @@ def __init__( name: str | None = None, digest: str | None = None, ): - """Hety. + """Initializes the ParquetDataset. Args: path: Local path or URI to the Parquet file or directory. From 689cd6c90a8f8cd3b409732c7678028e885e8092 Mon Sep 17 00:00:00 2001 From: JakubPekar Date: Wed, 24 Jun 2026 10:50:50 +0200 Subject: [PATCH 6/7] feat: CR fix --- rationai/mlkit/mlflow/parquet_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rationai/mlkit/mlflow/parquet_dataset.py b/rationai/mlkit/mlflow/parquet_dataset.py index c9da4e7..862484c 100644 --- a/rationai/mlkit/mlflow/parquet_dataset.py +++ b/rationai/mlkit/mlflow/parquet_dataset.py @@ -45,7 +45,7 @@ def __init__( def _compute_digest(self) -> str: """Computes a fast digest for the dataset based on schema and file paths.""" - hasher = hashlib.md5() + hasher = hashlib.md5(usedforsecurity=False) # Hash the schema structure hasher.update(str(self._ds.schema).encode("utf-8")) From c9b00eec370d00be4320da9f9ae25d37503bb86a Mon Sep 17 00:00:00 2001 From: JakubPekar Date: Thu, 25 Jun 2026 11:38:56 +0200 Subject: [PATCH 7/7] fix: CR --- rationai/mlkit/data/samplers/stratified_batch_sampler.py | 2 +- rationai/mlkit/mlflow/parquet_dataset.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/rationai/mlkit/data/samplers/stratified_batch_sampler.py b/rationai/mlkit/data/samplers/stratified_batch_sampler.py index dcf332f..c2603de 100644 --- a/rationai/mlkit/data/samplers/stratified_batch_sampler.py +++ b/rationai/mlkit/data/samplers/stratified_batch_sampler.py @@ -96,7 +96,7 @@ def __init__( data: pd.DataFrame, stratify_by: str | list[str], batch_size: int, - **kwargs: dict[str, Any], + **kwargs: Any, ) -> None: """Initializes the PDMStratifiedBatchSampler with DataFrame and batch size. diff --git a/rationai/mlkit/mlflow/parquet_dataset.py b/rationai/mlkit/mlflow/parquet_dataset.py index 862484c..107969d 100644 --- a/rationai/mlkit/mlflow/parquet_dataset.py +++ b/rationai/mlkit/mlflow/parquet_dataset.py @@ -4,6 +4,7 @@ from functools import cached_property from typing import Any +import pyarrow as pa import pyarrow.dataset as ds from mlflow.data.dataset import Dataset from mlflow.data.dataset_source import DatasetSource @@ -105,9 +106,8 @@ def profile(self) -> Any: def schema(self) -> Schema: """MLflow Schema representing the dataset features.""" try: - # Fetch an empty Pandas dataframe from the schema to utilize MLflow's built-in - # inference securely and correctly map PyArrow types to MLflow types. - empty_df = self._ds.head(0).to_pandas() + # Create an empty PyArrow Table from the schema and convert to Pandas. + empty_df = pa.Table.from_batches([], schema=self._ds.schema).to_pandas() inferred_schema = _infer_schema(empty_df) return inferred_schema except Exception as e: