# Package note: ``rationai/mlkit/data/__init__.py`` re-exports the function below:
#
#     from rationai.mlkit.data.shard_parquet import shard_parquet
#
#     __all__ = ["shard_parquet"]
#
# Everything that follows is the content of ``rationai/mlkit/data/shard_parquet.py``.
import logging
from pathlib import Path

import pyarrow.parquet as pq


_logger = logging.getLogger(__name__)


class _ShardingArgumentError(ValueError, AssertionError):
    """Raised when ``shard_parquet`` receives an invalid size argument.

    It is a ``ValueError`` because the problem is a bad argument value. It also
    derives from ``AssertionError`` so that callers written against the earlier
    ``assert``-based validation (``except AssertionError``) keep working.
    """


def shard_parquet(
    input_file: str | Path,
    output_dir: str | Path,
    rows_per_shard: int = 100_000,
    row_group_size: int = 5000,
) -> None:
    """Split a Parquet file into smaller Parquet files (shards).

    The source file is streamed in record batches of at most `row_group_size`
    rows and written to ``shard_00000.parquet``, ``shard_00001.parquet``, ...
    inside `output_dir` (created if missing). Every shard contains exactly
    `rows_per_shard` rows, except possibly the last one, which holds the
    remainder. A source file that yields no rows produces no shard files.

    Args:
        input_file: The path to the source Parquet file.
        output_dir: The directory where the output shards will be saved.
        rows_per_shard: The number of rows per shard. Defaults to 100,000.
        row_group_size: The maximum number of rows read/written per batch.
            Defaults to 5,000.

    Raises:
        ValueError: If `rows_per_shard` or `row_group_size` is not strictly
            positive, or if `rows_per_shard` is not divisible by
            `row_group_size`. The raised exception is also an
            ``AssertionError`` for backward compatibility.
    """
    # Explicit raises rather than ``assert``: assertions are removed when Python
    # runs with ``-O``, which would silently disable this validation.
    if rows_per_shard <= 0:
        raise _ShardingArgumentError("rows_per_shard must be greater than 0")
    if row_group_size <= 0:
        raise _ShardingArgumentError("row_group_size must be greater than 0")
    if rows_per_shard % row_group_size != 0:
        raise _ShardingArgumentError(
            "rows_per_shard must be divisible by row_group_size"
        )

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    with pq.ParquetFile(input_file) as parquet_file:
        _logger.info("Total rows in source: %s", parquet_file.metadata.num_rows)

        shard_idx = 0
        current_shard_rows = 0
        writer = None

        try:
            for batch in parquet_file.iter_batches(batch_size=row_group_size):
                # ``batch_size`` is only an upper bound, so a batch may be
                # shorter than ``row_group_size``. Slicing the batch at the
                # shard boundary keeps every shard at exactly
                # ``rows_per_shard`` rows regardless of the batch sizes.
                offset = 0
                while offset < batch.num_rows:
                    if writer is None:
                        out_path = output_dir / f"shard_{shard_idx:05d}.parquet"
                        writer = pq.ParquetWriter(out_path, batch.schema)

                    take = min(
                        rows_per_shard - current_shard_rows,
                        batch.num_rows - offset,
                    )
                    writer.write_batch(batch.slice(offset, take))
                    offset += take
                    current_shard_rows += take

                    # The shard is full: finalize it and start a new one lazily,
                    # so no empty trailing shard is ever created.
                    if current_shard_rows == rows_per_shard:
                        writer.close()
                        writer = None
                        _logger.info("Finished writing shard %05d", shard_idx)
                        shard_idx += 1
                        current_shard_rows = 0

            # A partially filled last shard is still open here; close it before
            # reporting it as finished.
            if writer is not None:
                writer.close()
                writer = None
                _logger.info("Finished writing final shard %05d", shard_idx)

        finally:
            # Only reached with an open writer when an error interrupted the
            # loop; release the file handle either way.
            if writer is not None:
                writer.close()

    _logger.info("Sharding complete!")