Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions rationai/mlkit/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from rationai.mlkit.data.shard_parquet import shard_parquet


__all__ = ["shard_parquet"]
85 changes: 85 additions & 0 deletions rationai/mlkit/data/shard_parquet.py
Comment thread
JakubPekar marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import logging
from pathlib import Path

import pyarrow.parquet as pq


_logger = logging.getLogger(__name__)


def shard_parquet(
input_file: str | Path,
output_dir: str | Path,
rows_per_shard: int = 100_000,
row_group_size: int = 5000,
) -> None:
Comment thread
JakubPekar marked this conversation as resolved.
"""Splits a large Parquet file into smaller Parquet files (shards).

This function reads a single Parquet file in memory-efficient batches and writes
it out into multiple smaller files. Each output file will contain exactly
`rows_per_shard` rows, except potentially the final shard.

Args:
input_file (str | Path): The path to the source Parquet file.
output_dir (str | Path): The directory where the output shards will be saved.
rows_per_shard (int, optional): The target number of rows per shard.
Defaults to 100,000.
row_group_size (int, optional): The number of rows to read/write per batch.
Defaults to 5,000.

Raises:
AssertionError: If `rows_per_shard` or `row_group_size` are not strictly positive,
or if `rows_per_shard` is not perfectly divisible by `row_group_size`.
"""
# --- Input Validation ---
assert rows_per_shard > 0, "rows_per_shard must be greater than 0"
assert row_group_size > 0, "row_group_size must be greater than 0"

# Ensure exact chunks can be written without remainder
assert rows_per_shard % row_group_size == 0, (
"rows_per_shard must be divisible by row_group_size"
)
Comment thread
JakubPekar marked this conversation as resolved.

# --- Setup Output Directory ---
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

# --- Read and Shard Process ---
with pq.ParquetFile(input_file) as parquet_file:
_logger.info(f"Total rows in source: {parquet_file.metadata.num_rows}")

# Initialize tracking variables
shard_idx = 0
current_shard_rows = 0
writer = None

try:
# Iterate through the source file in memory-efficient chunks (batches)
for batch in parquet_file.iter_batches(batch_size=row_group_size):
if writer is None:
out_path = output_dir / f"shard_{shard_idx:05d}.parquet"
writer = pq.ParquetWriter(out_path, batch.schema)

Comment thread
JakubPekar marked this conversation as resolved.
# Write the current batch
writer.write_batch(batch)
current_shard_rows += batch.num_rows

# Check if the current shard has reached its maximum capacity
if current_shard_rows >= rows_per_shard:
writer.close()
writer = None

_logger.info(f"Finished writing shard {shard_idx:05d}")

shard_idx += 1
current_shard_rows = 0

if writer is not None:
_logger.info(f"Finished writing final shard {shard_idx:05d}")

finally:
# Ensure the active writer is properly closed
if writer is not None:
writer.close()

_logger.info("Sharding complete!")
Loading