From 27e07a772761990ed93ba19df2dcd2c29d886369 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 08:13:08 -0600 Subject: [PATCH 01/20] feat: add spark.comet.cache.serializer.enabled config --- .../src/main/scala/org/apache/comet/CometConf.scala | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/spark/src/main/scala/org/apache/comet/CometConf.scala b/spark/src/main/scala/org/apache/comet/CometConf.scala index 78ea0f0168..69baf97204 100644 --- a/spark/src/main/scala/org/apache/comet/CometConf.scala +++ b/spark/src/main/scala/org/apache/comet/CometConf.scala @@ -226,6 +226,19 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) + val COMET_CACHE_SERIALIZER_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.cache.serializer.enabled") + .category(CATEGORY_EXEC) + .doc( + "When enabled, Comet installs a CachedBatchSerializer that stores Spark's in-memory " + + "table cache as compressed Arrow IPC. Repeated scans of cached data are then read " + + "natively without a per-read conversion. Schemas Comet cannot handle transparently " + + "fall back to Spark's default cache serializer. This sets " + + "spark.sql.cache.serializer for the session unless that property is already set " + + "to a non-default value. Disabled by default.") + .booleanConf + .createWithDefault(false) + val COMET_EXEC_ENABLED: ConfigEntry[Boolean] = conf(s"$COMET_EXEC_CONFIG_PREFIX.enabled") .category(CATEGORY_EXEC) .doc( From 0a9c233612bcda087d4626c089c5400b5e3b0422 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 08:18:35 -0600 Subject: [PATCH 02/20] feat: add CometCachedBatch and column stats helper --- .../comet/CometCachedBatchSerializer.scala | 112 ++++++++++++++++++ .../CometCachedBatchSerializerSuite.scala | 56 +++++++++ 2 files changed, 168 insertions(+) create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala create mode 100644 spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala new file mode 100644 index 0000000000..f1d029dd26 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow} +import org.apache.spark.sql.columnar.SimpleMetricsCachedBatch +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A cached batch holding one compressed Arrow IPC message plus Spark-format column stats. + * + * @param numRows + * number of rows in this batch + * @param bytes + * compressed Arrow IPC bytes for a single record batch + * @param stats + * InternalRow laid out as ColumnStatisticsSchema expects: per column [lowerBound, upperBound, + * nullCount, count, sizeInBytes] + */ +case class CometCachedBatch(numRows: Int, bytes: Array[Byte], stats: InternalRow) + extends SimpleMetricsCachedBatch { + // Used by InMemoryRelation to estimate the cached relation size; must reflect real bytes. + override def sizeInBytes: Long = bytes.length.toLong +} + +/** + * Accumulates per-column min/max/null/count for a set of rows and emits the stats InternalRow in + * the exact layout Spark's ColumnStatisticsSchema / SimpleMetricsCachedBatchSerializer expects. + * + * For column data types where a total ordering is not implemented here, the lower/upper bounds + * are left null. Null bounds mean "cannot prune" and are always correct (this is how Spark itself + * encodes unknown stats). + */ +class CometCacheColumnStats(attributes: Seq[Attribute]) { + private val numCols = attributes.length + private val lower = new Array[Any](numCols) + private val upper = new Array[Any](numCols) + private val nulls = new Array[Int](numCols) + private var rowCount = 0 + + /** Update column `ordinal` with one value. `value` is in Catalyst internal form (or null). */ + def update(ordinal: Int, dt: DataType, isNull: Boolean, value: Any): Unit = { + if (isNull) { + nulls(ordinal) += 1 + return + } + if (!ordered(dt)) return // leave bounds null for unsupported-stat types + if (lower(ordinal) == null || compare(dt, value, lower(ordinal)) < 0) lower(ordinal) = value + if (upper(ordinal) == null || compare(dt, value, upper(ordinal)) > 0) upper(ordinal) = value + } + + def setRowCount(n: Int): Unit = rowCount = n + + def toInternalRow: InternalRow = { + val values = new Array[Any](numCols * 5) + var i = 0 + while (i < numCols) { + val base = i * 5 + values(base) = lower(i) // lowerBound (column data type or null) + values(base + 1) = upper(i) // upperBound + values(base + 2) = nulls(i) // nullCount (Int) + values(base + 3) = rowCount // count (Int) + values(base + 4) = 0L // sizeInBytes (Long); not used by buildFilter + i += 1 + } + new GenericInternalRow(values) + } + + private def ordered(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | + _: DecimalType | StringType | DateType | TimestampType => + true + case _ => false + } + + private def compare(dt: DataType, x: Any, y: Any): Int = dt match { + case BooleanType => + java.lang.Boolean.compare(x.asInstanceOf[Boolean], y.asInstanceOf[Boolean]) + case ByteType => java.lang.Byte.compare(x.asInstanceOf[Byte], y.asInstanceOf[Byte]) + case ShortType => java.lang.Short.compare(x.asInstanceOf[Short], y.asInstanceOf[Short]) + case IntegerType | DateType => + java.lang.Integer.compare(x.asInstanceOf[Int], y.asInstanceOf[Int]) + case LongType | TimestampType => + java.lang.Long.compare(x.asInstanceOf[Long], y.asInstanceOf[Long]) + case FloatType => java.lang.Float.compare(x.asInstanceOf[Float], y.asInstanceOf[Float]) + case DoubleType => java.lang.Double.compare(x.asInstanceOf[Double], y.asInstanceOf[Double]) + case _: DecimalType => + x.asInstanceOf[org.apache.spark.sql.types.Decimal] + .compare(y.asInstanceOf[org.apache.spark.sql.types.Decimal]) + case StringType => x.asInstanceOf[UTF8String].binaryCompare(y.asInstanceOf[UTF8String]) + case other => throw new IllegalStateException(s"compare called for unordered type $other") + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala new file mode 100644 index 0000000000..6d4d26ca5e --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.comet.{CometCacheColumnStats, CometCachedBatch} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class CometCachedBatchSerializerSuite extends CometTestBase { + + test("stats row has 5 fields per column in cachedAttributes order") { + val a = AttributeReference("a", IntegerType, nullable = true)() + val b = AttributeReference("b", StringType, nullable = true)() + val acc = new CometCacheColumnStats(Seq(a, b)) + // column 0: values 5, null, 3 ; column 1: "y", "a", null + acc.update(0, IntegerType, isNull = false, 5) + acc.update(0, IntegerType, isNull = true, null) + acc.update(0, IntegerType, isNull = false, 3) + acc.update(1, StringType, isNull = false, UTF8String.fromString("y")) + acc.update(1, StringType, isNull = false, UTF8String.fromString("a")) + acc.update(1, StringType, isNull = true, null) + acc.setRowCount(3) + val stats = acc.toInternalRow + + assert(stats.numFields == 10) // 5 fields * 2 columns + // column 0: [lower=3, upper=5, nullCount=1, count=3, sizeInBytes=0] + assert(stats.getInt(0) == 3) + assert(stats.getInt(1) == 5) + assert(stats.getInt(2) == 1) + assert(stats.getInt(3) == 3) + // column 1: [lower="a", upper="y", nullCount=1, count=3, sizeInBytes=0] + assert(stats.getUTF8String(5) == UTF8String.fromString("a")) + assert(stats.getUTF8String(6) == UTF8String.fromString("y")) + assert(stats.getInt(7) == 1) + assert(stats.getInt(8) == 3) + } +} From 82517f7222ec0f0371af3d536680de33f4a2b4f0 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 08:30:13 -0600 Subject: [PATCH 03/20] fix: use ByteArray.compareBinary for version-safe string stats ordering --- .../apache/spark/sql/comet/CometCachedBatchSerializer.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala index f1d029dd26..aacf523029 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow} import org.apache.spark.sql.columnar.SimpleMetricsCachedBatch import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.ByteArray import org.apache.spark.unsafe.types.UTF8String /** @@ -106,7 +107,10 @@ class CometCacheColumnStats(attributes: Seq[Attribute]) { case _: DecimalType => x.asInstanceOf[org.apache.spark.sql.types.Decimal] .compare(y.asInstanceOf[org.apache.spark.sql.types.Decimal]) - case StringType => x.asInstanceOf[UTF8String].binaryCompare(y.asInstanceOf[UTF8String]) + case StringType => + ByteArray.compareBinary( + x.asInstanceOf[UTF8String].getBytes, + y.asInstanceOf[UTF8String].getBytes) case other => throw new IllegalStateException(s"compare called for unordered type $other") } } From cf3a73c1442c603802138c3020382ffd41b53c5f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 08:36:47 -0600 Subject: [PATCH 04/20] test: exercise CometCachedBatch and document setRowCount ordering --- .../spark/sql/comet/CometCachedBatchSerializer.scala | 5 +++++ .../org/apache/comet/CometCachedBatchSerializerSuite.scala | 7 +++++++ 2 files changed, 12 insertions(+) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala index aacf523029..9ac81d33d0 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala @@ -69,6 +69,11 @@ class CometCacheColumnStats(attributes: Seq[Attribute]) { if (upper(ordinal) == null || compare(dt, value, upper(ordinal)) > 0) upper(ordinal) = value } + /** + * Sets the total row count for this batch (the `count` stat field). Must be called before + * `toInternalRow`; otherwise `count` stays 0 and predicates like IsNotNull could incorrectly + * prune a non-empty batch. + */ def setRowCount(n: Int): Unit = rowCount = n def toInternalRow: InternalRow = { diff --git a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala index 6d4d26ca5e..1a14454613 100644 --- a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala @@ -52,5 +52,12 @@ class CometCachedBatchSerializerSuite extends CometTestBase { assert(stats.getUTF8String(6) == UTF8String.fromString("y")) assert(stats.getInt(7) == 1) assert(stats.getInt(8) == 3) + // sizeInBytes stat slots (positions 4 and 9) are 0L; they are not used by buildFilter + assert(stats.getLong(4) == 0L) + assert(stats.getLong(9) == 0L) + // CometCachedBatch.sizeInBytes reflects the IPC byte length + val cb = CometCachedBatch(numRows = 3, bytes = Array[Byte](1, 2, 3, 4, 5), stats = stats) + assert(cb.sizeInBytes == 5L) + assert(cb.numRows == 3) } } From 6c80ef0ef918f1030d594cc9fc609ae77e67d31c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 08:41:20 -0600 Subject: [PATCH 05/20] feat: add CometCachedBatchSerializer skeleton with schema routing --- .../comet/CometCachedBatchSerializer.scala | 61 ++++++++++++++++++- .../CometCachedBatchSerializerSuite.scala | 9 +++ 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala index 9ac81d33d0..9f6e2a42a1 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala @@ -19,9 +19,12 @@ package org.apache.spark.sql.comet +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow} -import org.apache.spark.sql.columnar.SimpleMetricsCachedBatch +import org.apache.spark.sql.columnar.{SimpleMetricsCachedBatch, SimpleMetricsCachedBatchSerializer} +import org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.ByteArray import org.apache.spark.unsafe.types.UTF8String @@ -119,3 +122,59 @@ class CometCacheColumnStats(attributes: Seq[Attribute]) { case other => throw new IllegalStateException(s"compare called for unordered type $other") } } + +class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer with Logging { + + // Delegate target for schemas Comet does not handle. Serializable (no-arg constructor). + private val fallback = new DefaultCachedBatchSerializer + + /** Comet handles flat schemas of the data types its Arrow conversion supports. */ + private def isCometSchema(dataTypes: Seq[DataType]): Boolean = + dataTypes.forall(isCometType) + + private def isCometType(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | + _: DecimalType | StringType | BinaryType | DateType | TimestampType => + true + // Nested/complex types are out of scope for v1; delegate to the default serializer. + case _ => false + } + + private def cometSchema(attrs: Seq[Attribute]): Boolean = isCometSchema(attrs.map(_.dataType)) + + // Force the row build path for Comet schemas (single code path for encode + stats); delegate + // otherwise so the default serializer's columnar-input optimization still applies. + override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = + if (cometSchema(schema)) false else fallback.supportsColumnarInput(schema) + + override def supportsColumnarOutput(schema: StructType): Boolean = + if (isCometSchema(schema.map(_.dataType))) true else fallback.supportsColumnarOutput(schema) + + // Let Spark use generic ColumnVector access; our columns are heterogeneous CometVector subtypes. + override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = None + + override def convertInternalRowToCachedBatch( + input: org.apache.spark.rdd.RDD[InternalRow], + schema: Seq[Attribute], + storageLevel: org.apache.spark.storage.StorageLevel, + conf: SQLConf): org.apache.spark.rdd.RDD[org.apache.spark.sql.columnar.CachedBatch] = ??? + + override def convertColumnarBatchToCachedBatch( + input: org.apache.spark.rdd.RDD[org.apache.spark.sql.vectorized.ColumnarBatch], + schema: Seq[Attribute], + storageLevel: org.apache.spark.storage.StorageLevel, + conf: SQLConf): org.apache.spark.rdd.RDD[org.apache.spark.sql.columnar.CachedBatch] = ??? + + override def convertCachedBatchToColumnarBatch( + input: org.apache.spark.rdd.RDD[org.apache.spark.sql.columnar.CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): org.apache.spark.rdd.RDD[org.apache.spark.sql.vectorized.ColumnarBatch] = + ??? + + override def convertCachedBatchToInternalRow( + input: org.apache.spark.rdd.RDD[org.apache.spark.sql.columnar.CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): org.apache.spark.rdd.RDD[InternalRow] = ??? +} diff --git a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala index 1a14454613..25af67c1db 100644 --- a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala @@ -60,4 +60,13 @@ class CometCachedBatchSerializerSuite extends CometTestBase { assert(cb.sizeInBytes == 5L) assert(cb.numRows == 3) } + + test("supportsColumnarOutput: true for flat supported schema, delegated for nested") { + val ser = new org.apache.spark.sql.comet.CometCachedBatchSerializer + val flat = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType))) + val nested = StructType(Seq(StructField("a", ArrayType(IntegerType)))) + assert(ser.supportsColumnarOutput(flat)) + // nested delegates to DefaultCachedBatchSerializer, which does not support columnar output + assert(!ser.supportsColumnarOutput(nested)) + } } From ec936f9f33d9be53399ec170f8fa17a41f46a04a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 08:49:01 -0600 Subject: [PATCH 06/20] refactor: simplify CometCachedBatchSerializer schema routing and imports --- .../comet/CometCachedBatchSerializer.scala | 33 +++++++++---------- .../CometCachedBatchSerializerSuite.scala | 4 +-- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala index 9f6e2a42a1..ee8e156c64 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.comet -import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow} -import org.apache.spark.sql.columnar.{SimpleMetricsCachedBatch, SimpleMetricsCachedBatchSerializer} +import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatch, SimpleMetricsCachedBatchSerializer} import org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.types.ByteArray import org.apache.spark.unsafe.types.UTF8String @@ -123,7 +125,7 @@ class CometCacheColumnStats(attributes: Seq[Attribute]) { } } -class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer with Logging { +class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { // Delegate target for schemas Comet does not handle. Serializable (no-arg constructor). private val fallback = new DefaultCachedBatchSerializer @@ -140,12 +142,10 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer with case _ => false } - private def cometSchema(attrs: Seq[Attribute]): Boolean = isCometSchema(attrs.map(_.dataType)) - // Force the row build path for Comet schemas (single code path for encode + stats); delegate // otherwise so the default serializer's columnar-input optimization still applies. override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = - if (cometSchema(schema)) false else fallback.supportsColumnarInput(schema) + if (isCometSchema(schema.map(_.dataType))) false else fallback.supportsColumnarInput(schema) override def supportsColumnarOutput(schema: StructType): Boolean = if (isCometSchema(schema.map(_.dataType))) true else fallback.supportsColumnarOutput(schema) @@ -154,27 +154,26 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer with override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = None override def convertInternalRowToCachedBatch( - input: org.apache.spark.rdd.RDD[InternalRow], + input: RDD[InternalRow], schema: Seq[Attribute], - storageLevel: org.apache.spark.storage.StorageLevel, - conf: SQLConf): org.apache.spark.rdd.RDD[org.apache.spark.sql.columnar.CachedBatch] = ??? + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = ??? override def convertColumnarBatchToCachedBatch( - input: org.apache.spark.rdd.RDD[org.apache.spark.sql.vectorized.ColumnarBatch], + input: RDD[ColumnarBatch], schema: Seq[Attribute], - storageLevel: org.apache.spark.storage.StorageLevel, - conf: SQLConf): org.apache.spark.rdd.RDD[org.apache.spark.sql.columnar.CachedBatch] = ??? + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = ??? override def convertCachedBatchToColumnarBatch( - input: org.apache.spark.rdd.RDD[org.apache.spark.sql.columnar.CachedBatch], + input: RDD[CachedBatch], cacheAttributes: Seq[Attribute], selectedAttributes: Seq[Attribute], - conf: SQLConf): org.apache.spark.rdd.RDD[org.apache.spark.sql.vectorized.ColumnarBatch] = - ??? + conf: SQLConf): RDD[ColumnarBatch] = ??? override def convertCachedBatchToInternalRow( - input: org.apache.spark.rdd.RDD[org.apache.spark.sql.columnar.CachedBatch], + input: RDD[CachedBatch], cacheAttributes: Seq[Attribute], selectedAttributes: Seq[Attribute], - conf: SQLConf): org.apache.spark.rdd.RDD[InternalRow] = ??? + conf: SQLConf): RDD[InternalRow] = ??? } diff --git a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala index 25af67c1db..7172caeff8 100644 --- a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala @@ -21,7 +21,7 @@ package org.apache.comet import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.comet.{CometCacheColumnStats, CometCachedBatch} +import org.apache.spark.sql.comet.{CometCacheColumnStats, CometCachedBatch, CometCachedBatchSerializer} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -62,7 +62,7 @@ class CometCachedBatchSerializerSuite extends CometTestBase { } test("supportsColumnarOutput: true for flat supported schema, delegated for nested") { - val ser = new org.apache.spark.sql.comet.CometCachedBatchSerializer + val ser = new CometCachedBatchSerializer val flat = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType))) val nested = StructType(Seq(StructField("a", ArrayType(IntegerType)))) assert(ser.supportsColumnarOutput(flat)) From fe216191974df49d59e155ab53802a107c08a48a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 08:59:39 -0600 Subject: [PATCH 07/20] feat: encode cached batches as compressed Arrow IPC with stats --- .../comet/CometCachedBatchSerializer.scala | 102 +++++++++++++++++- .../CometCachedBatchSerializerSuite.scala | 24 +++++ 2 files changed, 121 insertions(+), 5 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala index ee8e156c64..f8d0eb8969 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala @@ -19,18 +19,23 @@ package org.apache.spark.sql.comet +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow} import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatch, SimpleMetricsCachedBatchSerializer} +import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters +import org.apache.spark.sql.comet.util.{Utils => CometUtils} import org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.types.ByteArray import org.apache.spark.unsafe.types.UTF8String +import org.apache.comet.CometConf + /** * A cached batch holding one compressed Arrow IPC message plus Spark-format column stats. * @@ -153,27 +158,114 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { // Let Spark use generic ColumnVector access; our columns are heterogeneous CometVector subtypes. override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = None + private def toStructType(attrs: Seq[Attribute]): StructType = + StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + + // Compute stats from an already-built Arrow ColumnarBatch (columns are CometVector). + private def computeStats(batch: ColumnarBatch, attrs: Seq[Attribute]): InternalRow = { + val acc = new CometCacheColumnStats(attrs) + val numRows = batch.numRows() + var c = 0 + while (c < attrs.length) { + val dt = attrs(c).dataType + val col = batch.column(c) + var r = 0 + while (r < numRows) { + if (col.isNullAt(r)) { + acc.update(c, dt, isNull = true, null) + } else { + acc.update(c, dt, isNull = false, readValue(col, dt, r)) + } + r += 1 + } + c += 1 + } + acc.setRowCount(numRows) + acc.toInternalRow + } + + // Read one value in Catalyst internal form from a ColumnVector. + private def readValue(col: ColumnVector, dt: DataType, r: Int): Any = dt match { + case BooleanType => col.getBoolean(r) + case ByteType => col.getByte(r) + case ShortType => col.getShort(r) + case IntegerType | DateType => col.getInt(r) + case LongType | TimestampType => col.getLong(r) + case FloatType => col.getFloat(r) + case DoubleType => col.getDouble(r) + case d: DecimalType => col.getDecimal(r, d.precision, d.scale) + case StringType => col.getUTF8String(r) + case _ => null // BinaryType etc.: no stats bounds + } + + // Encode a single Arrow ColumnarBatch to compressed Arrow IPC bytes. + private def encodeBytes(batch: ColumnarBatch): Array[Byte] = { + val it = CometUtils.serializeBatches(Iterator.single(batch)) + val (_, cbb) = it.next() + cbb.toArray + } + + private def encode( + arrowBatches: Iterator[ColumnarBatch], + attrs: Seq[Attribute]): Iterator[CachedBatch] = + arrowBatches.map { batch => + val stats = computeStats(batch, attrs) + val bytes = encodeBytes(batch) + CometCachedBatch(batch.numRows(), bytes, stats).asInstanceOf[CachedBatch] + } + override def convertInternalRowToCachedBatch( input: RDD[InternalRow], schema: Seq[Attribute], storageLevel: StorageLevel, - conf: SQLConf): RDD[CachedBatch] = ??? + conf: SQLConf): RDD[CachedBatch] = { + if (!isCometSchema(schema.map(_.dataType))) { + return fallback.convertInternalRowToCachedBatch(input, schema, storageLevel, conf) + } + val structType = toStructType(schema) + val attrs = schema + val maxRecords = CometConf.COMET_BATCH_SIZE.get(conf).toLong + input.mapPartitions { rowIter => + val ctx = TaskContext.get() + val arrowBatches = + CometArrowConverters.rowToArrowBatchIter(rowIter, structType, maxRecords, "UTC", ctx) + encode(arrowBatches, attrs) + } + } override def convertColumnarBatchToCachedBatch( input: RDD[ColumnarBatch], schema: Seq[Attribute], storageLevel: StorageLevel, - conf: SQLConf): RDD[CachedBatch] = ??? + conf: SQLConf): RDD[CachedBatch] = { + if (!isCometSchema(schema.map(_.dataType))) { + return fallback.convertColumnarBatchToCachedBatch(input, schema, storageLevel, conf) + } + // Defensive: supportsColumnarInput returns false for Comet schemas so this is rarely + // called, but implement it correctly by converting each Spark batch to Arrow first. + val structType = toStructType(schema) + val attrs = schema + val maxRecords = CometConf.COMET_BATCH_SIZE.get(conf) + input.mapPartitions { batchIter => + val ctx = TaskContext.get() + val arrowBatches = batchIter.flatMap { b => + CometArrowConverters.columnarBatchToArrowBatchIter(b, structType, maxRecords, "UTC", ctx) + } + encode(arrowBatches, attrs) + } + } override def convertCachedBatchToColumnarBatch( input: RDD[CachedBatch], cacheAttributes: Seq[Attribute], selectedAttributes: Seq[Attribute], - conf: SQLConf): RDD[ColumnarBatch] = ??? + conf: SQLConf): RDD[ColumnarBatch] = + throw new UnsupportedOperationException("read path not yet implemented") override def convertCachedBatchToInternalRow( input: RDD[CachedBatch], cacheAttributes: Seq[Attribute], selectedAttributes: Seq[Attribute], - conf: SQLConf): RDD[InternalRow] = ??? + conf: SQLConf): RDD[InternalRow] = + throw new UnsupportedOperationException("read path not yet implemented") } diff --git a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala index 7172caeff8..f701b2f9a9 100644 --- a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala @@ -25,6 +25,8 @@ import org.apache.spark.sql.comet.{CometCacheColumnStats, CometCachedBatch, Come import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.comet.CometConf + class CometCachedBatchSerializerSuite extends CometTestBase { test("stats row has 5 fields per column in cachedAttributes order") { @@ -69,4 +71,26 @@ class CometCachedBatchSerializerSuite extends CometTestBase { // nested delegates to DefaultCachedBatchSerializer, which does not support columnar output assert(!ser.supportsColumnarOutput(nested)) } + + test("build path produces one CometCachedBatch per Arrow batch with stats") { + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "100") { + val ser = new CometCachedBatchSerializer + val df = spark.range(250).selectExpr("id", "cast(id as string) as s") + val attrs = df.queryExecution.analyzed.output + val rdd = df.queryExecution.toRdd + val cached = ser + .convertInternalRowToCachedBatch( + rdd, + attrs, + org.apache.spark.storage.StorageLevel.MEMORY_ONLY, + spark.sessionState.conf) + .collect() + assert(cached.forall(_.isInstanceOf[CometCachedBatch])) + assert(cached.map(_.numRows).sum == 250) + cached.foreach { b => + assert(b.sizeInBytes > 0) + assert(b.asInstanceOf[CometCachedBatch].stats.numFields == attrs.length * 5) + } + } + } } From bf522db868af4b3ef1edf237eaec6f3a5036fe46 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 09:14:49 -0600 Subject: [PATCH 08/20] feat: support TimestampNTZ in cache stats and harden build test Add TimestampNTZType alongside TimestampType in isCometType, readValue, CometCacheColumnStats.ordered, and CometCacheColumnStats.compare so that timestamp-without-timezone columns are cached natively. Remove the unnecessary .asInstanceOf[CachedBatch] cast from encode (Iterator covariance makes it redundant), drop the dead val attrs = schema aliases in both build methods, and add a lifecycle comment on encodeBytes documenting that stats must be computed before serialization clears the VectorSchemaRoot. Strengthen the build test to use coalesce(1) for deterministic batch count (exactly 3 batches for 250 rows at batch size 100) and assert real stat values: minimum lowerBound of the id column is 0 and nullCount is 0. --- .../comet/CometCachedBatchSerializer.scala | 21 ++++++++++--------- .../CometCachedBatchSerializerSuite.scala | 10 ++++++++- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala index f8d0eb8969..8cec102e18 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala @@ -103,7 +103,7 @@ class CometCacheColumnStats(attributes: Seq[Attribute]) { private def ordered(dt: DataType): Boolean = dt match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | - _: DecimalType | StringType | DateType | TimestampType => + _: DecimalType | StringType | DateType | TimestampType | TimestampNTZType => true case _ => false } @@ -115,7 +115,7 @@ class CometCacheColumnStats(attributes: Seq[Attribute]) { case ShortType => java.lang.Short.compare(x.asInstanceOf[Short], y.asInstanceOf[Short]) case IntegerType | DateType => java.lang.Integer.compare(x.asInstanceOf[Int], y.asInstanceOf[Int]) - case LongType | TimestampType => + case LongType | TimestampType | TimestampNTZType => java.lang.Long.compare(x.asInstanceOf[Long], y.asInstanceOf[Long]) case FloatType => java.lang.Float.compare(x.asInstanceOf[Float], y.asInstanceOf[Float]) case DoubleType => java.lang.Double.compare(x.asInstanceOf[Double], y.asInstanceOf[Double]) @@ -141,7 +141,7 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { private def isCometType(dt: DataType): Boolean = dt match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | - _: DecimalType | StringType | BinaryType | DateType | TimestampType => + _: DecimalType | StringType | BinaryType | DateType | TimestampType | TimestampNTZType => true // Nested/complex types are out of scope for v1; delegate to the default serializer. case _ => false @@ -190,7 +190,7 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { case ByteType => col.getByte(r) case ShortType => col.getShort(r) case IntegerType | DateType => col.getInt(r) - case LongType | TimestampType => col.getLong(r) + case LongType | TimestampType | TimestampNTZType => col.getLong(r) case FloatType => col.getFloat(r) case DoubleType => col.getDouble(r) case d: DecimalType => col.getDecimal(r, d.precision, d.scale) @@ -198,7 +198,10 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { case _ => null // BinaryType etc.: no stats bounds } - // Encode a single Arrow ColumnarBatch to compressed Arrow IPC bytes. + // INVARIANT: compute stats BEFORE calling this. serializeBatches internally clears the + // VectorSchemaRoot wrapping the batch's field vectors, so the batch must not be read after + // this call. The row/columnar Arrow iterators reset those vectors before producing the next + // batch, so the clear is safe as long as we never touch this batch again. private def encodeBytes(batch: ColumnarBatch): Array[Byte] = { val it = CometUtils.serializeBatches(Iterator.single(batch)) val (_, cbb) = it.next() @@ -211,7 +214,7 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { arrowBatches.map { batch => val stats = computeStats(batch, attrs) val bytes = encodeBytes(batch) - CometCachedBatch(batch.numRows(), bytes, stats).asInstanceOf[CachedBatch] + CometCachedBatch(batch.numRows(), bytes, stats) } override def convertInternalRowToCachedBatch( @@ -223,13 +226,12 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { return fallback.convertInternalRowToCachedBatch(input, schema, storageLevel, conf) } val structType = toStructType(schema) - val attrs = schema val maxRecords = CometConf.COMET_BATCH_SIZE.get(conf).toLong input.mapPartitions { rowIter => val ctx = TaskContext.get() val arrowBatches = CometArrowConverters.rowToArrowBatchIter(rowIter, structType, maxRecords, "UTC", ctx) - encode(arrowBatches, attrs) + encode(arrowBatches, schema) } } @@ -244,14 +246,13 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { // Defensive: supportsColumnarInput returns false for Comet schemas so this is rarely // called, but implement it correctly by converting each Spark batch to Arrow first. val structType = toStructType(schema) - val attrs = schema val maxRecords = CometConf.COMET_BATCH_SIZE.get(conf) input.mapPartitions { batchIter => val ctx = TaskContext.get() val arrowBatches = batchIter.flatMap { b => CometArrowConverters.columnarBatchToArrowBatchIter(b, structType, maxRecords, "UTC", ctx) } - encode(arrowBatches, attrs) + encode(arrowBatches, schema) } } diff --git a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala index f701b2f9a9..3b25af398e 100644 --- a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala @@ -75,7 +75,8 @@ class CometCachedBatchSerializerSuite extends CometTestBase { test("build path produces one CometCachedBatch per Arrow batch with stats") { withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "100") { val ser = new CometCachedBatchSerializer - val df = spark.range(250).selectExpr("id", "cast(id as string) as s") + // coalesce(1) makes the batch chunking deterministic: 250 rows / 100 = 3 batches + val df = spark.range(250).coalesce(1).selectExpr("id", "cast(id as string) as s") val attrs = df.queryExecution.analyzed.output val rdd = df.queryExecution.toRdd val cached = ser @@ -85,12 +86,19 @@ class CometCachedBatchSerializerSuite extends CometTestBase { org.apache.spark.storage.StorageLevel.MEMORY_ONLY, spark.sessionState.conf) .collect() + assert(cached.length == 3) assert(cached.forall(_.isInstanceOf[CometCachedBatch])) assert(cached.map(_.numRows).sum == 250) cached.foreach { b => assert(b.sizeInBytes > 0) assert(b.asInstanceOf[CometCachedBatch].stats.numFields == attrs.length * 5) } + // column 0 is the bigint id; verify real (non-null) stats were computed + val statRows = cached.map(_.asInstanceOf[CometCachedBatch].stats) + // lowerBound of col 0 lives at field 0 (LongType); min across batches must be 0 + assert(statRows.map(_.getLong(0)).min == 0L) + // nullCount of col 0 lives at field 2; range() has no nulls + assert(statRows.forall(_.getInt(2) == 0)) } } } From 29cb5116d4db64348e179b74d18d80bc15649e98 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 09:25:11 -0600 Subject: [PATCH 09/20] feat: decode Comet cached batches with column pruning and row fallback --- .../comet/CometCachedBatchSerializer.scala | 84 ++++++++++++++++++- .../CometCachedBatchSerializerSuite.scala | 41 +++++++++ 2 files changed, 121 insertions(+), 4 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala index 8cec102e18..b4b2492f59 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.comet +import java.nio.ByteBuffer + import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -33,6 +35,7 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.types.ByteArray import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.io.ChunkedByteBuffer import org.apache.comet.CometConf @@ -256,17 +259,90 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { } } + // Map selected attributes to their column indices within cacheAttributes by exprId. + private def selectedIndices( + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute]): Array[Int] = { + val byId = cacheAttributes.map(_.exprId).zipWithIndex.toMap + selectedAttributes.map(a => byId(a.exprId)).toArray + } + + // Returns true if indices is exactly [0, 1, 2, ..., n-1] (identity projection). + private def isIdentityProjection(indices: Array[Int], numCols: Int): Boolean = { + if (indices.length != numCols) return false + var i = 0 + while (i < indices.length) { + if (indices(i) != i) return false + i += 1 + } + true + } + + // Decode one CometCachedBatch into a ColumnarBatch projected to the selected columns. + private def decodeOne(b: CometCachedBatch, indices: Array[Int]): Iterator[ColumnarBatch] = { + val chunked = new ChunkedByteBuffer(ByteBuffer.wrap(b.bytes)) + CometUtils.decodeBatches(chunked, "CometCachedBatch").map { full => + if (isIdentityProjection(indices, full.numCols())) { + full + } else { + val cols = indices.map(full.column) + new ColumnarBatch(cols, full.numRows()) + } + } + } + + // Version-safe conversion of a ColumnarBatch's java row iterator to copied Scala InternalRows. + private def rowsOf(batch: ColumnarBatch): Iterator[InternalRow] = { + val it = batch.rowIterator() + new Iterator[InternalRow] { + override def hasNext: Boolean = it.hasNext + override def next(): InternalRow = it.next().copy() + } + } + override def convertCachedBatchToColumnarBatch( input: RDD[CachedBatch], cacheAttributes: Seq[Attribute], selectedAttributes: Seq[Attribute], - conf: SQLConf): RDD[ColumnarBatch] = - throw new UnsupportedOperationException("read path not yet implemented") + conf: SQLConf): RDD[ColumnarBatch] = { + if (!isCometSchema(cacheAttributes.map(_.dataType))) { + return fallback.convertCachedBatchToColumnarBatch( + input, + cacheAttributes, + selectedAttributes, + conf) + } + val indices = selectedIndices(cacheAttributes, selectedAttributes) + input.mapPartitions { batchIter => + batchIter.flatMap { + case b: CometCachedBatch => decodeOne(b, indices) + case other => + throw new IllegalStateException( + s"Expected CometCachedBatch but got ${other.getClass.getName}") + } + } + } override def convertCachedBatchToInternalRow( input: RDD[CachedBatch], cacheAttributes: Seq[Attribute], selectedAttributes: Seq[Attribute], - conf: SQLConf): RDD[InternalRow] = - throw new UnsupportedOperationException("read path not yet implemented") + conf: SQLConf): RDD[InternalRow] = { + if (!isCometSchema(cacheAttributes.map(_.dataType))) { + return fallback.convertCachedBatchToInternalRow( + input, + cacheAttributes, + selectedAttributes, + conf) + } + val indices = selectedIndices(cacheAttributes, selectedAttributes) + input.mapPartitions { batchIter => + batchIter.flatMap { + case b: CometCachedBatch => decodeOne(b, indices).flatMap(rowsOf) + case other => + throw new IllegalStateException( + s"Expected CometCachedBatch but got ${other.getClass.getName}") + } + } + } } diff --git a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala index 3b25af398e..dc09f5078f 100644 --- a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala @@ -101,4 +101,45 @@ class CometCachedBatchSerializerSuite extends CometTestBase { assert(statRows.forall(_.getInt(2) == 0)) } } + + test("round-trip: build then decode all columns matches input") { + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "64") { + val ser = new CometCachedBatchSerializer + val df = spark.range(200).coalesce(1).selectExpr("id", "cast(id * 2 as string) as s") + val attrs = df.queryExecution.analyzed.output + val cached = ser.convertInternalRowToCachedBatch( + df.queryExecution.toRdd, + attrs, + org.apache.spark.storage.StorageLevel.MEMORY_ONLY, + spark.sessionState.conf) + val decodedRows = ser + .convertCachedBatchToInternalRow(cached, attrs, attrs, spark.sessionState.conf) + .map(r => (r.getLong(0), r.getUTF8String(1).toString)) + .collect() + .toSet + val expected = (0 until 200).map(i => (i.toLong, (i * 2).toString)).toSet + assert(decodedRows == expected) + } + } + + test("read path prunes to selected columns") { + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "64") { + val ser = new CometCachedBatchSerializer + val df = spark.range(200).coalesce(1).selectExpr("id", "cast(id * 2 as string) as s") + val attrs = df.queryExecution.analyzed.output + val cached = ser.convertInternalRowToCachedBatch( + df.queryExecution.toRdd, + attrs, + org.apache.spark.storage.StorageLevel.MEMORY_ONLY, + spark.sessionState.conf) + // select only the string column (index 1) + val onlyS = Seq(attrs(1)) + val pruned = ser + .convertCachedBatchToInternalRow(cached, attrs, onlyS, spark.sessionState.conf) + .map(_.getUTF8String(0).toString) + .collect() + .toSet + assert(pruned == (0 until 200).map(i => (i * 2).toString).toSet) + } + } } From 4f4aa160cf3f99001ff7c2162fd2ee91f57d63bc Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 09:39:19 -0600 Subject: [PATCH 10/20] test: cover columnar cache read path and clarify selectedIndices errors Add a test for convertCachedBatchToColumnarBatch exercising both the identity-projection passthrough (full columns) and pruned projection. Improve selectedIndices to throw IllegalStateException with a diagnostic message instead of a raw map key lookup. Update the isIdentityProjection comment to state both conditions explicitly. --- .../comet/CometCachedBatchSerializer.scala | 9 ++- .../CometCachedBatchSerializerSuite.scala | 64 +++++++++++++++++++ 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala index b4b2492f59..dc73bcce94 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala @@ -264,10 +264,15 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { cacheAttributes: Seq[Attribute], selectedAttributes: Seq[Attribute]): Array[Int] = { val byId = cacheAttributes.map(_.exprId).zipWithIndex.toMap - selectedAttributes.map(a => byId(a.exprId)).toArray + selectedAttributes.map { a => + byId.getOrElse( + a.exprId, + throw new IllegalStateException( + s"Selected attribute $a (exprId ${a.exprId}) not found in cached attributes")) + }.toArray } - // Returns true if indices is exactly [0, 1, 2, ..., n-1] (identity projection). + // True when `indices` selects every column in order: length == numCols and indices(i) == i. private def isIdentityProjection(indices: Array[Int], numCols: Int): Boolean = { if (indices.length != numCols) return false var i = 0 diff --git a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala index dc09f5078f..f40c24005a 100644 --- a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala @@ -142,4 +142,68 @@ class CometCachedBatchSerializerSuite extends CometTestBase { assert(pruned == (0 until 200).map(i => (i * 2).toString).toSet) } } + + test("columnar read path: full and pruned projection") { + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "64") { + val ser = new CometCachedBatchSerializer + val df = spark.range(100).coalesce(1).selectExpr("id", "cast(id * 2 as string) as s") + val attrs = df.queryExecution.analyzed.output + val cached = ser.convertInternalRowToCachedBatch( + df.queryExecution.toRdd, + attrs, + org.apache.spark.storage.StorageLevel.MEMORY_ONLY, + spark.sessionState.conf) + + // Full projection (identity passthrough): 2 columns, values match. + val fullColCounts = + ser + .convertCachedBatchToColumnarBatch(cached, attrs, attrs, spark.sessionState.conf) + .map(_.numCols()) + .collect() + assert(fullColCounts.forall(_ == 2)) + val fullVals = + ser + .convertCachedBatchToColumnarBatch(cached, attrs, attrs, spark.sessionState.conf) + .mapPartitions { batches => + batches.flatMap { b => + val rows = new scala.collection.mutable.ArrayBuffer[(Long, String)] + var i = 0 + while (i < b.numRows()) { + rows += ((b.column(0).getLong(i), b.column(1).getUTF8String(i).toString)) + i += 1 + } + rows.iterator + } + } + .collect() + .toSet + assert(fullVals == (0 until 100).map(i => (i.toLong, (i * 2).toString)).toSet) + + // Pruned projection: only the string column (index 1) -> 1 column, correct values. + val onlyS = Seq(attrs(1)) + val prunedColCounts = + ser + .convertCachedBatchToColumnarBatch(cached, attrs, onlyS, spark.sessionState.conf) + .map(_.numCols()) + .collect() + assert(prunedColCounts.forall(_ == 1)) + val prunedVals = + ser + .convertCachedBatchToColumnarBatch(cached, attrs, onlyS, spark.sessionState.conf) + .mapPartitions { batches => + batches.flatMap { b => + val rows = new scala.collection.mutable.ArrayBuffer[String] + var i = 0 + while (i < b.numRows()) { + rows += b.column(0).getUTF8String(i).toString + i += 1 + } + rows.iterator + } + } + .collect() + .toSet + assert(prunedVals == (0 until 100).map(i => (i * 2).toString).toSet) + } + } } From 3c1cd5f02e135ca77a9abcbce93f3b163d480adc Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 09:49:49 -0600 Subject: [PATCH 11/20] feat: pass already-Arrow batches through CometSparkToColumnarExec without copy When CometSparkToColumnarExec receives batches that are already Arrow (all columns are CometVector, e.g. from CometCachedBatchSerializer), skip the columnarBatchToArrowBatchIter re-copy and pass the batch through directly. Adds a numPassthroughBatches metric to track when the fast-path fires. --- .../sql/comet/CometSparkToColumnarExec.scala | 36 ++++++++++++++----- .../CometCachedBatchSerializerSuite.scala | 34 +++++++++++++++++- 2 files changed, 61 insertions(+), 9 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala index efe6a97d40..99dd802f76 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala @@ -67,7 +67,10 @@ case class CometSparkToColumnarExec(child: SparkPlan) "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches"), "conversionTime" -> SQLMetrics.createNanoTimingMetric( sparkContext, - "time converting Spark batches to Arrow batches")) + "time converting Spark batches to Arrow batches"), + "numPassthroughBatches" -> SQLMetrics.createMetric( + sparkContext, + "number of already-Arrow batches passed through without conversion")) // The conversion happens in next(), so wrap the call to measure time spent. private def createTimingIter( @@ -96,6 +99,7 @@ case class CometSparkToColumnarExec(child: SparkPlan) val numInputRows = longMetric("numInputRows") val numOutputBatches = longMetric("numOutputBatches") val conversionTime = longMetric("conversionTime") + val numPassthroughBatches = longMetric("numPassthroughBatches") val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) // Use UTC for Arrow schema timezone to match the native side, which always // deserializes Timestamp as Timestamp(Microsecond, Some("UTC")). Spark's internal @@ -111,13 +115,19 @@ case class CometSparkToColumnarExec(child: SparkPlan) .mapPartitionsInternal { sparkBatches => val arrowBatches = sparkBatches.flatMap { sparkBatch => - val context = TaskContext.get() - CometArrowConverters.columnarBatchToArrowBatchIter( - sparkBatch, - schema, - maxRecordsPerBatch, - timeZoneId, - context) + if (isAllCometVectors(sparkBatch)) { + // Already Arrow (e.g. from CometCachedBatchSerializer): pass through, no copy. + numPassthroughBatches += 1 + Iterator.single(sparkBatch) + } else { + val context = TaskContext.get() + CometArrowConverters.columnarBatchToArrowBatchIter( + sparkBatch, + schema, + maxRecordsPerBatch, + timeZoneId, + context) + } } createTimingIter(arrowBatches, numInputRows, numOutputBatches, conversionTime) } @@ -141,6 +151,16 @@ case class CometSparkToColumnarExec(child: SparkPlan) override protected def withNewChildInternal(newChild: SparkPlan): CometSparkToColumnarExec = copy(child = newChild) + private def isAllCometVectors(batch: ColumnarBatch): Boolean = { + if (batch.numCols() == 0) return false + var i = 0 + while (i < batch.numCols()) { + if (!batch.column(i).isInstanceOf[org.apache.comet.vector.CometVector]) return false + i += 1 + } + true + } + } object CometSparkToColumnarExec extends CometSink[SparkPlan] with DataTypeSupport { diff --git a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala index f40c24005a..6641bf6192 100644 --- a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala @@ -21,7 +21,7 @@ package org.apache.comet import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.comet.{CometCacheColumnStats, CometCachedBatch, CometCachedBatchSerializer} +import org.apache.spark.sql.comet.{CometCacheColumnStats, CometCachedBatch, CometCachedBatchSerializer, CometSparkToColumnarExec} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -29,6 +29,11 @@ import org.apache.comet.CometConf class CometCachedBatchSerializerSuite extends CometTestBase { + override protected def sparkConf: org.apache.spark.SparkConf = { + super.sparkConf + .set("spark.sql.cache.serializer", "org.apache.spark.sql.comet.CometCachedBatchSerializer") + } + test("stats row has 5 fields per column in cachedAttributes order") { val a = AttributeReference("a", IntegerType, nullable = true)() val b = AttributeReference("b", StringType, nullable = true)() @@ -206,4 +211,31 @@ class CometCachedBatchSerializerSuite extends CometTestBase { assert(prunedVals == (0 until 100).map(i => (i * 2).toString).toSet) } } + + test("cached scan passes already-Arrow batches through CometSparkToColumnarExec") { + withSQLConf( + org.apache.spark.sql.internal.SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { + spark + .range(1000) + .selectExpr("id as key", "id % 8 as value") + .createOrReplaceTempView("comet_cache_c1") + spark.catalog.cacheTable("comet_cache_c1") + try { + // groupBy forces a CometSparkToColumnarExec to appear above the cached InMemoryTableScan. + val df = spark.sql("SELECT value, count(*) FROM comet_cache_c1 GROUP BY value") + val rows = df.collect() + assert(rows.length == 8) + val s2c = collectFirst(df.queryExecution.executedPlan) { + case s: CometSparkToColumnarExec => s + } + // CometSparkToColumnarExec must appear above the cached scan and must have taken the + // passthrough fast-path (batches already Arrow, no re-copy needed). + assert(s2c.isDefined, "expected CometSparkToColumnarExec in plan over cached scan") + assert(s2c.get.metrics("numPassthroughBatches").value > 0L) + } finally { + spark.catalog.uncacheTable("comet_cache_c1") + } + } + } } From 826249f52f8d091cb83ebc537269200d86fe96c3 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 10:05:10 -0600 Subject: [PATCH 12/20] refactor: import CometVector and clarify passthrough test/metric --- .../apache/spark/sql/comet/CometSparkToColumnarExec.scala | 5 +++-- .../org/apache/comet/CometCachedBatchSerializerSuite.scala | 5 ++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala index 99dd802f76..f03ba18d36 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.comet.{CometConf, DataTypeSupport} import org.apache.comet.serde.OperatorOuterClass import org.apache.comet.serde.operator.CometSink +import org.apache.comet.vector.CometVector case class CometSparkToColumnarExec(child: SparkPlan) extends RowToColumnarTransition @@ -70,7 +71,7 @@ case class CometSparkToColumnarExec(child: SparkPlan) "time converting Spark batches to Arrow batches"), "numPassthroughBatches" -> SQLMetrics.createMetric( sparkContext, - "number of already-Arrow batches passed through without conversion")) + "number of passthrough Arrow batches")) // The conversion happens in next(), so wrap the call to measure time spent. private def createTimingIter( @@ -155,7 +156,7 @@ case class CometSparkToColumnarExec(child: SparkPlan) if (batch.numCols() == 0) return false var i = 0 while (i < batch.numCols()) { - if (!batch.column(i).isInstanceOf[org.apache.comet.vector.CometVector]) return false + if (!batch.column(i).isInstanceOf[CometVector]) return false i += 1 } true diff --git a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala index 6641bf6192..4e36c74755 100644 --- a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala @@ -29,6 +29,9 @@ import org.apache.comet.CometConf class CometCachedBatchSerializerSuite extends CometTestBase { + // spark.sql.cache.serializer is a STATIC SQL config (cannot be set via withSQLConf at + // runtime), so it must be set at session creation. This makes the whole suite use the Comet + // cache serializer; the pure-unit tests construct a serializer directly and are unaffected. override protected def sparkConf: org.apache.spark.SparkConf = { super.sparkConf .set("spark.sql.cache.serializer", "org.apache.spark.sql.comet.CometCachedBatchSerializer") @@ -215,7 +218,7 @@ class CometCachedBatchSerializerSuite extends CometTestBase { test("cached scan passes already-Arrow batches through CometSparkToColumnarExec") { withSQLConf( org.apache.spark.sql.internal.SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", - CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { + CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { // jvm shuffle keeps a CometSparkToColumnarExec in the plan over the cached scan spark .range(1000) .selectExpr("id as key", "id % 8 as value") From 4960cfb99876cce2bae9ae31d293c7e4742afb02 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 10:12:44 -0600 Subject: [PATCH 13/20] feat: install Comet cache serializer from CometDriverPlugin when enabled Wire COMET_CACHE_SERIALIZER_ENABLED to CometDriverPlugin.init so the static spark.sql.cache.serializer config is set on the SparkConf at SparkContext startup. A user-provided non-default serializer is respected and not overridden. --- .../main/scala/org/apache/spark/Plugins.scala | 24 +++++++++++++ .../org/apache/spark/CometPluginsSuite.scala | 36 +++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/spark/src/main/scala/org/apache/spark/Plugins.scala b/spark/src/main/scala/org/apache/spark/Plugins.scala index 7290ab436a..9280b57487 100644 --- a/spark/src/main/scala/org/apache/spark/Plugins.scala +++ b/spark/src/main/scala/org/apache/spark/Plugins.scala @@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD, EXECUTOR_MEMORY_OVERHEAD_FACTOR} import org.apache.spark.sql.internal.StaticSQLConf +import org.apache.comet.CometConf import org.apache.comet.CometConf.{COMET_METRICS_ENABLED, COMET_ONHEAP_ENABLED} import org.apache.comet.CometSparkSessionExtensions @@ -57,6 +58,9 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl // register CometSparkSessionExtensions if it isn't already registered CometDriverPlugin.registerCometSessionExtension(sc.conf) + // Install the Comet cache serializer if requested + CometDriverPlugin.setCacheSerializerIfEnabled(sc.conf) + // Register Comet metrics CometDriverPlugin.registerCometMetrics(sc) @@ -104,6 +108,26 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl } object CometDriverPlugin extends Logging { + private[spark] val CACHE_SERIALIZER_KEY = "spark.sql.cache.serializer" + private[spark] val COMET_CACHE_SERIALIZER = + "org.apache.spark.sql.comet.CometCachedBatchSerializer" + private[spark] val DEFAULT_CACHE_SERIALIZER = + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer" + + /** + * If the Comet cache serializer is enabled, install it as Spark's cache serializer. This is a + * static SQL config, so it must be set on the SparkConf before the session is created. A + * user-provided non-default serializer is respected and not overridden. + */ + private[spark] def setCacheSerializerIfEnabled(conf: SparkConf): Unit = { + if (conf.getBoolean(CometConf.COMET_CACHE_SERIALIZER_ENABLED.key, defaultValue = false)) { + val existing = conf.get(CACHE_SERIALIZER_KEY, "") + if (existing.isEmpty || existing == DEFAULT_CACHE_SERIALIZER) { + conf.set(CACHE_SERIALIZER_KEY, COMET_CACHE_SERIALIZER) + } + } + } + def registerCometMetrics(sc: SparkContext): Unit = { if (sc.getConf.getBoolean( COMET_METRICS_ENABLED.key, diff --git a/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala b/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala index fa5f368e33..85cb1e53df 100644 --- a/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala @@ -24,6 +24,8 @@ import java.io.File import org.apache.spark.sql.{CometTestBase, SaveMode} import org.apache.spark.sql.internal.StaticSQLConf +import org.apache.comet.CometConf + class CometPluginsSuite extends CometTestBase { override protected def sparkConf: SparkConf = { val conf = new SparkConf() @@ -83,6 +85,40 @@ class CometPluginsSuite extends CometTestBase { } } + test("setCacheSerializerIfEnabled installs Comet serializer when enabled and unset") { + val conf = new SparkConf().set(CometConf.COMET_CACHE_SERIALIZER_ENABLED.key, "true") + CometDriverPlugin.setCacheSerializerIfEnabled(conf) + assert( + conf.get("spark.sql.cache.serializer") == + "org.apache.spark.sql.comet.CometCachedBatchSerializer") + } + + test("setCacheSerializerIfEnabled replaces the default serializer when enabled") { + val conf = new SparkConf() + .set(CometConf.COMET_CACHE_SERIALIZER_ENABLED.key, "true") + .set( + "spark.sql.cache.serializer", + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + CometDriverPlugin.setCacheSerializerIfEnabled(conf) + assert( + conf.get("spark.sql.cache.serializer") == + "org.apache.spark.sql.comet.CometCachedBatchSerializer") + } + + test("setCacheSerializerIfEnabled respects a user-provided serializer") { + val conf = new SparkConf() + .set(CometConf.COMET_CACHE_SERIALIZER_ENABLED.key, "true") + .set("spark.sql.cache.serializer", "com.example.MyCachedBatchSerializer") + CometDriverPlugin.setCacheSerializerIfEnabled(conf) + assert(conf.get("spark.sql.cache.serializer") == "com.example.MyCachedBatchSerializer") + } + + test("setCacheSerializerIfEnabled does nothing when disabled") { + val conf = new SparkConf() + CometDriverPlugin.setCacheSerializerIfEnabled(conf) + assert(conf.getOption("spark.sql.cache.serializer").isEmpty) + } + test("CometSource metrics are recorded") { val nativeBefore = CometSource.NATIVE_OPERATORS.getCount val queriesBefore = CometSource.QUERIES_PLANNED.getCount From bec8c30add659281bb60551419ebdd17c2ba3536 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 10:23:39 -0600 Subject: [PATCH 14/20] refactor: use StaticSQLConf.SPARK_CACHE_SERIALIZER and log serializer install --- .../main/scala/org/apache/spark/Plugins.scala | 17 ++++++++--------- .../org/apache/spark/CometPluginsSuite.scala | 19 ++++++++++--------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/Plugins.scala b/spark/src/main/scala/org/apache/spark/Plugins.scala index 9280b57487..26d77c4f94 100644 --- a/spark/src/main/scala/org/apache/spark/Plugins.scala +++ b/spark/src/main/scala/org/apache/spark/Plugins.scala @@ -28,8 +28,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD, EXECUTOR_MEMORY_OVERHEAD_FACTOR} import org.apache.spark.sql.internal.StaticSQLConf -import org.apache.comet.CometConf -import org.apache.comet.CometConf.{COMET_METRICS_ENABLED, COMET_ONHEAP_ENABLED} +import org.apache.comet.CometConf.{COMET_CACHE_SERIALIZER_ENABLED, COMET_METRICS_ENABLED, COMET_ONHEAP_ENABLED} import org.apache.comet.CometSparkSessionExtensions /** @@ -108,11 +107,8 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl } object CometDriverPlugin extends Logging { - private[spark] val CACHE_SERIALIZER_KEY = "spark.sql.cache.serializer" private[spark] val COMET_CACHE_SERIALIZER = "org.apache.spark.sql.comet.CometCachedBatchSerializer" - private[spark] val DEFAULT_CACHE_SERIALIZER = - "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer" /** * If the Comet cache serializer is enabled, install it as Spark's cache serializer. This is a @@ -120,10 +116,13 @@ object CometDriverPlugin extends Logging { * user-provided non-default serializer is respected and not overridden. */ private[spark] def setCacheSerializerIfEnabled(conf: SparkConf): Unit = { - if (conf.getBoolean(CometConf.COMET_CACHE_SERIALIZER_ENABLED.key, defaultValue = false)) { - val existing = conf.get(CACHE_SERIALIZER_KEY, "") - if (existing.isEmpty || existing == DEFAULT_CACHE_SERIALIZER) { - conf.set(CACHE_SERIALIZER_KEY, COMET_CACHE_SERIALIZER) + if (conf.getBoolean(COMET_CACHE_SERIALIZER_ENABLED.key, defaultValue = false)) { + val key = StaticSQLConf.SPARK_CACHE_SERIALIZER.key + val default = StaticSQLConf.SPARK_CACHE_SERIALIZER.defaultValueString + val existing = conf.get(key, default) + if (existing == default) { + logInfo(s"Setting $key=$COMET_CACHE_SERIALIZER") + conf.set(key, COMET_CACHE_SERIALIZER) } } } diff --git a/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala b/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala index 85cb1e53df..9b00e42593 100644 --- a/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala @@ -89,34 +89,35 @@ class CometPluginsSuite extends CometTestBase { val conf = new SparkConf().set(CometConf.COMET_CACHE_SERIALIZER_ENABLED.key, "true") CometDriverPlugin.setCacheSerializerIfEnabled(conf) assert( - conf.get("spark.sql.cache.serializer") == - "org.apache.spark.sql.comet.CometCachedBatchSerializer") + conf.get(StaticSQLConf.SPARK_CACHE_SERIALIZER.key) == + CometDriverPlugin.COMET_CACHE_SERIALIZER) } test("setCacheSerializerIfEnabled replaces the default serializer when enabled") { val conf = new SparkConf() .set(CometConf.COMET_CACHE_SERIALIZER_ENABLED.key, "true") .set( - "spark.sql.cache.serializer", - "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + StaticSQLConf.SPARK_CACHE_SERIALIZER.key, + StaticSQLConf.SPARK_CACHE_SERIALIZER.defaultValueString) CometDriverPlugin.setCacheSerializerIfEnabled(conf) assert( - conf.get("spark.sql.cache.serializer") == - "org.apache.spark.sql.comet.CometCachedBatchSerializer") + conf.get(StaticSQLConf.SPARK_CACHE_SERIALIZER.key) == + CometDriverPlugin.COMET_CACHE_SERIALIZER) } test("setCacheSerializerIfEnabled respects a user-provided serializer") { val conf = new SparkConf() .set(CometConf.COMET_CACHE_SERIALIZER_ENABLED.key, "true") - .set("spark.sql.cache.serializer", "com.example.MyCachedBatchSerializer") + .set(StaticSQLConf.SPARK_CACHE_SERIALIZER.key, "com.example.MyCachedBatchSerializer") CometDriverPlugin.setCacheSerializerIfEnabled(conf) - assert(conf.get("spark.sql.cache.serializer") == "com.example.MyCachedBatchSerializer") + assert( + conf.get(StaticSQLConf.SPARK_CACHE_SERIALIZER.key) == "com.example.MyCachedBatchSerializer") } test("setCacheSerializerIfEnabled does nothing when disabled") { val conf = new SparkConf() CometDriverPlugin.setCacheSerializerIfEnabled(conf) - assert(conf.getOption("spark.sql.cache.serializer").isEmpty) + assert(conf.getOption(StaticSQLConf.SPARK_CACHE_SERIALIZER.key).isEmpty) } test("CometSource metrics are recorded") { From 7ba47f4366339d8e2656f7ebb2d30f531513ee84 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 10:31:01 -0600 Subject: [PATCH 15/20] test: end-to-end Comet cache serializer correctness, pruning, spill, delegation, ntz --- .../CometCachedBatchSerializerSuite.scala | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala index 4e36c74755..a617d55efb 100644 --- a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala @@ -19,10 +19,13 @@ package org.apache.comet +import java.time.LocalDateTime + import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.comet.{CometCacheColumnStats, CometCachedBatch, CometCachedBatchSerializer, CometSparkToColumnarExec} import org.apache.spark.sql.types._ +import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.CometConf @@ -241,4 +244,82 @@ class CometCachedBatchSerializerSuite extends CometTestBase { } } } + + test("cached query result matches uncached") { + val base = spark + .range(2000) + .selectExpr("id as k", "id % 10 as v", "cast(id as string) as s") + val expected = + base + .groupBy("v") + .count() + .orderBy("v") + .collect() + .toSeq + .map(r => (r.getLong(0), r.getLong(1))) + base.createOrReplaceTempView("comet_cache_t8") + spark.catalog.cacheTable("comet_cache_t8") + try { + val df = spark.sql("SELECT v, count(*) AS c FROM comet_cache_t8 GROUP BY v ORDER BY v") + checkSparkAnswer(df) + val actual = df.collect().toSeq.map(r => (r.getLong(0), r.getLong(1))) + assert(actual == expected) + } finally { + spark.catalog.uncacheTable("comet_cache_t8") + } + } + + test("filtered cached scan returns correct rows with stats pruning") { + spark.range(5000).selectExpr("id as k").createOrReplaceTempView("comet_cache_t8f") + spark.catalog.cacheTable("comet_cache_t8f") + try { + val df = spark.sql("SELECT k FROM comet_cache_t8f WHERE k >= 4990") + checkSparkAnswer(df) + assert(df.count() == 10) + } finally { + spark.catalog.uncacheTable("comet_cache_t8f") + } + } + + test("cached table with MEMORY_AND_DISK round-trips") { + val cachedDf = spark + .range(3000) + .selectExpr("id as k", "cast(id as string) as s") + .persist(StorageLevel.MEMORY_AND_DISK) + try { + assert(cachedDf.count() == 3000) + checkSparkAnswer(cachedDf.filter("k % 2 = 0")) + } finally { + cachedDf.unpersist() + } + } + + test("array-typed cached relation delegates to default serializer and is correct") { + val df0 = spark.range(100).selectExpr("id as k", "array(id, id + 1) as a") + df0.createOrReplaceTempView("comet_cache_t8a") + spark.catalog.cacheTable("comet_cache_t8a") + try { + val df = spark.sql("SELECT k, a FROM comet_cache_t8a WHERE k < 5 ORDER BY k") + checkSparkAnswer(df) + assert(df.count() == 5) + } finally { + spark.catalog.uncacheTable("comet_cache_t8a") + } + } + + test("timestamp_ntz cached scan is correct") { + // A Seq[LocalDateTime] maps to TimestampNTZType, which the Comet serializer supports. + val data = (0 until 50).map(i => (i.toLong, LocalDateTime.of(2020, 1, 1, 0, 0, i % 60))) + import testImplicits._ + val df0 = data.toDF("id", "ts") + df0.createOrReplaceTempView("comet_cache_ntz") + spark.catalog.cacheTable("comet_cache_ntz") + try { + val df = spark.sql("SELECT id, ts FROM comet_cache_ntz WHERE id < 10 ORDER BY id") + checkSparkAnswer(df) + assert(df.count() == 10) + } finally { + spark.catalog.uncacheTable("comet_cache_ntz") + } + } } From 9dc0f7db7c1259e84e68baf469a20be3961fb8e9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 10:43:50 -0600 Subject: [PATCH 16/20] test: assert cached values (not just counts) for filter and timestamp_ntz --- .../CometCachedBatchSerializerSuite.scala | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala index a617d55efb..55dbaed08f 100644 --- a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala @@ -32,6 +32,8 @@ import org.apache.comet.CometConf class CometCachedBatchSerializerSuite extends CometTestBase { + import testImplicits._ + // spark.sql.cache.serializer is a STATIC SQL config (cannot be set via withSQLConf at // runtime), so it must be set at session creation. This makes the whole suite use the Comet // cache serializer; the pure-unit tests construct a serializer directly and are unaffected. @@ -273,9 +275,10 @@ class CometCachedBatchSerializerSuite extends CometTestBase { spark.range(5000).selectExpr("id as k").createOrReplaceTempView("comet_cache_t8f") spark.catalog.cacheTable("comet_cache_t8f") try { - val df = spark.sql("SELECT k FROM comet_cache_t8f WHERE k >= 4990") + val df = spark.sql("SELECT k FROM comet_cache_t8f WHERE k >= 4990 ORDER BY k") checkSparkAnswer(df) - assert(df.count() == 10) + val actual = df.collect().map(_.getLong(0)).toSeq + assert(actual == (4990L until 5000L).toSeq) } finally { spark.catalog.uncacheTable("comet_cache_t8f") } @@ -310,14 +313,25 @@ class CometCachedBatchSerializerSuite extends CometTestBase { test("timestamp_ntz cached scan is correct") { // A Seq[LocalDateTime] maps to TimestampNTZType, which the Comet serializer supports. val data = (0 until 50).map(i => (i.toLong, LocalDateTime.of(2020, 1, 1, 0, 0, i % 60))) - import testImplicits._ val df0 = data.toDF("id", "ts") + // Expected values from the uncached DataFrame (before caching). + val expected = df0 + .where("id < 10") + .orderBy("id") + .collect() + .map(r => (r.getLong(0), r.getAs[java.time.LocalDateTime](1))) + .toSeq df0.createOrReplaceTempView("comet_cache_ntz") spark.catalog.cacheTable("comet_cache_ntz") try { val df = spark.sql("SELECT id, ts FROM comet_cache_ntz WHERE id < 10 ORDER BY id") checkSparkAnswer(df) - assert(df.count() == 10) + val actual = df + .collect() + .map(r => (r.getLong(0), r.getAs[java.time.LocalDateTime](1))) + .toSeq + assert(actual == expected) + assert(actual.size == 10) } finally { spark.catalog.uncacheTable("comet_cache_ntz") } From 40184cea67ad9456b21e07e46462764ee4d4330f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 10:59:04 -0600 Subject: [PATCH 17/20] fix: copy UTF8String for cache stats to avoid use-after-free on string bounds In CometCachedBatchSerializer.readValue, the StringType case previously returned col.getUTF8String(r) directly, which is a view into the Arrow value buffer (backed by UTF8String.fromAddress). The stats accumulator stored these views as lowerBound/upperBound, then encodeBytes called serializeBatches which clears the VectorSchemaRoot and releases those Arrow buffers. The stored string bounds then dangled, causing SimpleMetricsCachedBatchSerializer.buildFilter to prune batches using garbage stats, resulting in missing rows on filtered cached string scans. Fix by calling .copy() to materialize a heap copy before the buffer is freed. Add regression tests: one verifying stats survive encode, one verifying an equality filter on a cached string column returns correct rows. --- .../comet/CometCachedBatchSerializer.scala | 2 +- .../CometCachedBatchSerializerSuite.scala | 51 +++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala index dc73bcce94..885f6ea7e0 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala @@ -197,7 +197,7 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { case FloatType => col.getFloat(r) case DoubleType => col.getDouble(r) case d: DecimalType => col.getDecimal(r, d.precision, d.scale) - case StringType => col.getUTF8String(r) + case StringType => col.getUTF8String(r).copy() case _ => null // BinaryType etc.: no stats bounds } diff --git a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala index 55dbaed08f..e3c75f4e68 100644 --- a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala @@ -310,6 +310,57 @@ class CometCachedBatchSerializerSuite extends CometTestBase { } } + test("string column stats survive encode (no buffer use-after-free)") { + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "100") { + val ser = new CometCachedBatchSerializer + // zero-padded so lexicographic order is well-defined and stable + val df = spark + .range(250) + .coalesce(1) + .selectExpr("id", "lpad(cast(id as string), 5, '0') as s") + val attrs = df.queryExecution.analyzed.output + val cached = ser + .convertInternalRowToCachedBatch( + df.queryExecution.toRdd, + attrs, + org.apache.spark.storage.StorageLevel.MEMORY_ONLY, + spark.sessionState.conf) + .collect() + assert(cached.length == 3) + // column 1 is the string column; its stats live at fields [5..9]: + // field 5 = lowerBound, field 6 = upperBound + cached.zipWithIndex.foreach { case (b, batchIdx) => + val stats = b.asInstanceOf[CometCachedBatch].stats + val lo = stats.getUTF8String(5).toString + val hi = stats.getUTF8String(6).toString + val start = batchIdx * 100 + val end = math.min(start + 100, 250) - 1 + assert( + lo == f"$start%05d", + s"batch $batchIdx lowerBound was '$lo', expected ${f"$start%05d"}") + assert( + hi == f"$end%05d", + s"batch $batchIdx upperBound was '$hi', expected ${f"$end%05d"}") + } + } + } + + test("filtered cached scan on a string column returns correct rows") { + spark + .range(2000) + .selectExpr("lpad(cast(id as string), 5, '0') as s") + .createOrReplaceTempView("comet_cache_str") + spark.catalog.cacheTable("comet_cache_str") + try { + val df = spark.sql("SELECT s FROM comet_cache_str WHERE s = '01999'") + checkSparkAnswer(df) + val rows = df.collect().map(_.getString(0)).toSeq + assert(rows == Seq("01999")) + } finally { + spark.catalog.uncacheTable("comet_cache_str") + } + } + test("timestamp_ntz cached scan is correct") { // A Seq[LocalDateTime] maps to TimestampNTZType, which the Comet serializer supports. val data = (0 until 50).map(i => (i.toLong, LocalDateTime.of(2020, 1, 1, 0, 0, i % 60))) From d1af3a3ee6d84fc33b34725158b9c1b467ca01dc Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 11:08:22 -0600 Subject: [PATCH 18/20] chore: open as draft pull request [skip ci] From bedaa77d4886920d17f2371d84b9a5d7461f4ea0 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 11:22:32 -0600 Subject: [PATCH 19/20] refactor: dedup cache decode path and precompute stat orderedness --- .../comet/CometCachedBatchSerializer.scala | 52 ++++++++----------- 1 file changed, 22 insertions(+), 30 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala index 885f6ea7e0..e0eec845a4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala @@ -70,6 +70,7 @@ class CometCacheColumnStats(attributes: Seq[Attribute]) { private val upper = new Array[Any](numCols) private val nulls = new Array[Int](numCols) private var rowCount = 0 + private val tracksBounds: Array[Boolean] = attributes.map(a => ordered(a.dataType)).toArray /** Update column `ordinal` with one value. `value` is in Catalyst internal form (or null). */ def update(ordinal: Int, dt: DataType, isNull: Boolean, value: Any): Unit = { @@ -77,7 +78,7 @@ class CometCacheColumnStats(attributes: Seq[Attribute]) { nulls(ordinal) += 1 return } - if (!ordered(dt)) return // leave bounds null for unsupported-stat types + if (!tracksBounds(ordinal)) return // leave bounds null for unsupported-stat types if (lower(ordinal) == null || compare(dt, value, lower(ordinal)) < 0) lower(ordinal) = value if (upper(ordinal) == null || compare(dt, value, upper(ordinal)) > 0) upper(ordinal) = value } @@ -246,8 +247,9 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { if (!isCometSchema(schema.map(_.dataType))) { return fallback.convertColumnarBatchToCachedBatch(input, schema, storageLevel, conf) } - // Defensive: supportsColumnarInput returns false for Comet schemas so this is rarely - // called, but implement it correctly by converting each Spark batch to Arrow first. + // This branch is never reached for Comet schemas: supportsColumnarInput returns false for + // them, so Spark always takes the row path above. It is only reachable for delegated + // (non-Comet) schemas that somehow bypass the fallback guard, and is implemented defensively. val structType = toStructType(schema) val maxRecords = CometConf.COMET_BATCH_SIZE.get(conf) input.mapPartitions { batchIter => @@ -273,15 +275,8 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { } // True when `indices` selects every column in order: length == numCols and indices(i) == i. - private def isIdentityProjection(indices: Array[Int], numCols: Int): Boolean = { - if (indices.length != numCols) return false - var i = 0 - while (i < indices.length) { - if (indices(i) != i) return false - i += 1 - } - true - } + private def isIdentityProjection(indices: Array[Int], numCols: Int): Boolean = + indices.length == numCols && indices.indices.forall(i => indices(i) == i) // Decode one CometCachedBatch into a ColumnarBatch projected to the selected columns. private def decodeOne(b: CometCachedBatch, indices: Array[Int]): Iterator[ColumnarBatch] = { @@ -305,6 +300,18 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { } } + private def decodeCometBatches( + input: RDD[CachedBatch], + indices: Array[Int]): RDD[ColumnarBatch] = + input.mapPartitions { batchIter => + batchIter.flatMap { + case b: CometCachedBatch => decodeOne(b, indices) + case other => + throw new IllegalStateException( + s"Expected CometCachedBatch but got ${other.getClass.getName}") + } + } + override def convertCachedBatchToColumnarBatch( input: RDD[CachedBatch], cacheAttributes: Seq[Attribute], @@ -317,15 +324,7 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { selectedAttributes, conf) } - val indices = selectedIndices(cacheAttributes, selectedAttributes) - input.mapPartitions { batchIter => - batchIter.flatMap { - case b: CometCachedBatch => decodeOne(b, indices) - case other => - throw new IllegalStateException( - s"Expected CometCachedBatch but got ${other.getClass.getName}") - } - } + decodeCometBatches(input, selectedIndices(cacheAttributes, selectedAttributes)) } override def convertCachedBatchToInternalRow( @@ -340,14 +339,7 @@ class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { selectedAttributes, conf) } - val indices = selectedIndices(cacheAttributes, selectedAttributes) - input.mapPartitions { batchIter => - batchIter.flatMap { - case b: CometCachedBatch => decodeOne(b, indices).flatMap(rowsOf) - case other => - throw new IllegalStateException( - s"Expected CometCachedBatch but got ${other.getClass.getName}") - } - } + decodeCometBatches(input, selectedIndices(cacheAttributes, selectedAttributes)) + .flatMap(b => rowsOf(b)) } } From 920c57f862d99cdbca61902f7ed945196698c51f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 2 Jun 2026 11:40:46 -0600 Subject: [PATCH 20/20] chore: keep draft CI skipped after cleanup [skip ci]