diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 6fb1f5263b518..2446e107545c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal} +import org.apache.spark.unsafe.types.{CalendarInterval, TimestampNanosVal, VariantVal} /** * An `Iterator` like trait used to extract values from columnar byte buffer. When a value is @@ -115,6 +115,14 @@ private[columnar] class VariantColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[VariantVal](buffer, VARIANT) with NullableColumnAccessor +private[columnar] class TimestampNTZNanosColumnAccessor(buffer: ByteBuffer) + extends BasicColumnAccessor[TimestampNanosVal](buffer, TIMESTAMP_NTZ_NANOS) + with NullableColumnAccessor + +private[columnar] class TimestampLTZNanosColumnAccessor(buffer: ByteBuffer) + extends BasicColumnAccessor[TimestampNanosVal](buffer, TIMESTAMP_LTZ_NANOS) + with NullableColumnAccessor + private[columnar] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType)) @@ -153,6 +161,8 @@ private[sql] object ColumnAccessor { case DoubleType => new DoubleColumnAccessor(buf) case s: StringType => new StringColumnAccessor(buf, s) case BinaryType => new BinaryColumnAccessor(buf) + case _: TimestampNTZNanosType => new TimestampNTZNanosColumnAccessor(buf) + case _: TimestampLTZNanosType => new TimestampLTZNanosColumnAccessor(buf) case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => new CompactDecimalColumnAccessor(buf, dt) case dt: DecimalType => new DecimalColumnAccessor(buf, dt) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala index a63569b19a018..cfd0ea005e8ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala @@ -134,6 +134,12 @@ class IntervalColumnBuilder extends ComplexColumnBuilder(new IntervalColumnStats private[columnar] class VariantColumnBuilder extends ComplexColumnBuilder(new VariantColumnStats, VARIANT) +private[columnar] class TimestampNTZNanosColumnBuilder + extends ComplexColumnBuilder(new TimestampNanosColumnStats, TIMESTAMP_NTZ_NANOS) + +private[columnar] class TimestampLTZNanosColumnBuilder + extends ComplexColumnBuilder(new TimestampNanosColumnStats, TIMESTAMP_LTZ_NANOS) + private[columnar] class CompactDecimalColumnBuilder(dataType: DecimalType) extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType)) @@ -193,6 +199,8 @@ private[columnar] object ColumnBuilder { case BinaryType => new BinaryColumnBuilder case CalendarIntervalType => new IntervalColumnBuilder case VariantType => new VariantColumnBuilder + case _: TimestampNTZNanosType => new TimestampNTZNanosColumnBuilder + case _: TimestampLTZNanosType => new TimestampLTZNanosColumnBuilder case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => new CompactDecimalColumnBuilder(dt) case dt: DecimalType => new DecimalColumnBuilder(dt) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index 4e4b3667fa24f..c09c94ff4201d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{TimestampNanosVal, UTF8String} class ColumnStatisticsSchema(a: Attribute) extends Serializable { val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)() @@ -326,6 +326,30 @@ private[columnar] final class IntervalColumnStats extends ColumnStats { Array[Any](null, null, nullCount, count, sizeInBytes) } +private[columnar] final class TimestampNanosColumnStats extends ColumnStats { + protected var upper: TimestampNanosVal = null + protected var lower: TimestampNanosVal = null + + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + if (!row.isNullAt(ordinal)) { + // TimestampNanosVal has a total order matching calendar order, so collect min/max bounds + // (like DecimalColumnStats, not IntervalColumnStats) to enable partition pruning, matching + // the micro-precision timestamp path (LongColumnStats). NTZ and LTZ share the same physical + // payload, so a single getter reads the value for both. + val value = row.getTimestampNTZNanos(ordinal) + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + sizeInBytes += TimestampNanosVal.SIZE_IN_BYTES + count += 1 + } else { + gatherNullStats() + } + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) +} + private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { def this(dt: DecimalType) = this(dt.precision, dt.scale) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index df250e529e2ce..cf4309b52142b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} +import org.apache.spark.unsafe.types.{CalendarInterval, TimestampNanosVal, UTF8String, VariantVal} /** @@ -815,6 +815,73 @@ private[columnar] object CALENDAR_INTERVAL extends ColumnType[CalendarInterval] } } +/** + * Used to append/extract [[TimestampNanosVal]] into/from the underlying [[ByteBuffer]] of a + * column. The on-buffer layout mirrors the UnsafeRow variable-length payload: an 8-byte + * epochMicros followed by an 8-byte word holding nanosWithinMicro (zero-extended), 16 bytes total + * (see TimestampNanosRowValues). NTZ and LTZ share this storage and differ only by physical type, + * so the two singletons below pass their own physicalType and row getter/setter. + */ +private[columnar] abstract class TIMESTAMP_NANOS(physicalType: PhysicalDataType) + extends ColumnType[TimestampNanosVal] { + + override def dataType: PhysicalDataType = physicalType + + override def defaultSize: Int = TimestampNanosVal.SIZE_IN_BYTES + + protected def getNanos(row: InternalRow, ordinal: Int): TimestampNanosVal + protected def setNanos(row: InternalRow, ordinal: Int, value: TimestampNanosVal): Unit + + override def getField(row: InternalRow, ordinal: Int): TimestampNanosVal = getNanos(row, ordinal) + + override def setField(row: InternalRow, ordinal: Int, value: TimestampNanosVal): Unit = + setNanos(row, ordinal, value) + + override def extract(buffer: ByteBuffer): TimestampNanosVal = { + val epochMicros = ByteBufferHelper.getLong(buffer) + // The nanos field is stored in a full 8-byte word (matching the UnsafeRow payload), so read a + // long and narrow it; the writer guarantees the value is in [0, 999]. + val nanosWithinMicro = ByteBufferHelper.getLong(buffer).toShort + TimestampNanosVal.fromTrustedRowBytes(epochMicros, nanosWithinMicro) + } + + // Copy the fixed 16-byte payload straight into the UnsafeRow, like CALENDAR_INTERVAL. + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { + row match { + case mutable: MutableUnsafeRow => + val cursor = buffer.position() + buffer.position(cursor + defaultSize) + mutable.writer.write(ordinal, buffer.array(), + buffer.arrayOffset() + cursor, defaultSize) + case _ => + setField(row, ordinal, extract(buffer)) + } + } + + override def append(v: TimestampNanosVal, buffer: ByteBuffer): Unit = { + ByteBufferHelper.putLong(buffer, v.epochMicros) + ByteBufferHelper.putLong(buffer, v.nanosWithinMicro.toLong) + } +} + +private[columnar] object TIMESTAMP_NTZ_NANOS + extends TIMESTAMP_NANOS(PhysicalTimestampNTZNanosType) { + override protected def getNanos(row: InternalRow, ordinal: Int): TimestampNanosVal = + row.getTimestampNTZNanos(ordinal) + override protected def setNanos( + row: InternalRow, ordinal: Int, value: TimestampNanosVal): Unit = + row.setTimestampNTZNanos(ordinal, value) +} + +private[columnar] object TIMESTAMP_LTZ_NANOS + extends TIMESTAMP_NANOS(PhysicalTimestampLTZNanosType) { + override protected def getNanos(row: InternalRow, ordinal: Int): TimestampNanosVal = + row.getTimestampLTZNanos(ordinal) + override protected def setNanos( + row: InternalRow, ordinal: Int, value: TimestampNanosVal): Unit = + row.setTimestampLTZNanos(ordinal, value) +} + /** * Used to append/extract Java VariantVals into/from the underlying [[ByteBuffer]] of a column. * @@ -876,6 +943,8 @@ private[columnar] object ColumnType { case s: StringType => STRING(s) case BinaryType => BINARY case i: CalendarIntervalType => CALENDAR_INTERVAL + case _: TimestampNTZNanosType => TIMESTAMP_NTZ_NANOS + case _: TimestampLTZNanosType => TIMESTAMP_LTZ_NANOS case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => COMPACT_DECIMAL(dt) case dt: DecimalType => LARGE_DECIMAL(dt) case arr: ArrayType => ARRAY(PhysicalArrayType(arr.elementType, arr.containsNull)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index dd64d92bed71e..14ab652b4f077 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -90,6 +90,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera case BinaryType => classOf[BinaryColumnAccessor].getName case CalendarIntervalType => classOf[IntervalColumnAccessor].getName case VariantType => classOf[VariantColumnAccessor].getName + case _: TimestampNTZNanosType => classOf[TimestampNTZNanosColumnAccessor].getName + case _: TimestampLTZNanosType => classOf[TimestampLTZNanosColumnAccessor].getName case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => classOf[CompactDecimalColumnAccessor].getName case dt: DecimalType => classOf[DecimalColumnAccessor].getName @@ -102,7 +104,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera val createCode = dt match { case t if CodeGenerator.isPrimitiveType(dt) => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" - case NullType | BinaryType | CalendarIntervalType | VariantType => + case NullType | BinaryType | CalendarIntervalType | VariantType | + _: TimestampNTZNanosType | _: TimestampLTZNanosType => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" case other => s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index bdb118b91fa28..ac01a03c684dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.types.PhysicalDataType import org.apache.spark.sql.types.StringType +import org.apache.spark.unsafe.types.TimestampNanosVal class ColumnStatsSuite extends SparkFunSuite { testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0)) @@ -32,6 +33,8 @@ class ColumnStatsSuite extends SparkFunSuite { testDecimalColumnStats(Array(null, null, 0)) testIntervalColumnStats(Array(null, null, 0)) testStringColumnStats(Array(null, null, 0)) + testTimestampNanosColumnStats(TIMESTAMP_NTZ_NANOS, Array(null, null, 0)) + testTimestampNanosColumnStats(TIMESTAMP_LTZ_NANOS, Array(null, null, 0)) def testColumnStats[T <: PhysicalDataType, U <: ColumnStats]( columnStatsClass: Class[U], @@ -143,6 +146,40 @@ class ColumnStatsSuite extends SparkFunSuite { } } + def testTimestampNanosColumnStats( + columnType: ColumnType[TimestampNanosVal], + initialStatistics: Array[Any]): Unit = { + + val columnStatsName = classOf[TimestampNanosColumnStats].getSimpleName + + test(s"$columnStatsName ($columnType): empty") { + val columnStats = new TimestampNanosColumnStats + columnStats.collectedStatistics.zip(initialStatistics).foreach { + case (actual, expected) => assert(actual === expected) + } + } + + test(s"$columnStatsName ($columnType): non-empty collects min/max bounds") { + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ + + val columnStats = new TimestampNanosColumnStats + val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) + rows.foreach(columnStats.gatherStats(_, 0)) + + val values = rows.take(10).map(_.get(0, + ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType)) + .asInstanceOf[TimestampNanosVal]) + val ordering = Ordering.fromLessThan[TimestampNanosVal](_.compareTo(_) < 0) + val stats = columnStats.collectedStatistics + + assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) + assertResult(10, "Wrong null count")(stats(2)) + assertResult(20, "Wrong row count")(stats(3)) + assertResult(TimestampNanosVal.SIZE_IN_BYTES * 10 + 4 * 10, "Wrong size in bytes")(stats(4)) + } + } + def testStringColumnStats[T <: PhysicalDataType, U <: ColumnStats]( initialStatistics: Array[Any]): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index cb97066098f20..93f2ed85e53f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -44,7 +44,7 @@ class ColumnTypeSuite extends SparkFunSuite { STRING(StringType) -> 8, STRING(StringType("UTF8_LCASE")) -> 8, STRING(StringType("UNICODE")) -> 8, STRING(StringType("UNICODE_CI")) -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE -> 68, - CALENDAR_INTERVAL -> 16) + CALENDAR_INTERVAL -> 16, TIMESTAMP_NTZ_NANOS -> 16, TIMESTAMP_LTZ_NANOS -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -113,6 +113,8 @@ class ColumnTypeSuite extends SparkFunSuite { testColumnType(ARRAY_TYPE) testColumnType(MAP_TYPE) testColumnType(CALENDAR_INTERVAL) + testColumnType(TIMESTAMP_NTZ_NANOS) + testColumnType(TIMESTAMP_LTZ_NANOS) def testNativeColumnType[T <: PhysicalDataType](columnType: NativeColumnType[T]): Unit = { val typeName = columnType match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarDataTypeUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarDataTypeUtils.scala index 018ce36eb7836..639a7c99b3ab7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarDataTypeUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarDataTypeUtils.scala @@ -30,6 +30,8 @@ object ColumnarDataTypeUtils { case PhysicalShortType => ShortType case PhysicalBinaryType => BinaryType case PhysicalCalendarIntervalType => CalendarIntervalType + case PhysicalTimestampNTZNanosType => TimestampNTZNanosType() + case PhysicalTimestampLTZNanosType => TimestampLTZNanosType() case PhysicalFloatType => FloatType case PhysicalDoubleType => DoubleType case PhysicalStringType(collationId) => StringType(collationId) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala index d08c34056f565..044277d46e93e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.types.PhysicalDataType import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.Decimal -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, TimestampNanosVal, UTF8String} object ColumnarTestUtils { def makeNullRow(length: Int): GenericInternalRow = { @@ -54,6 +54,10 @@ object ColumnarTestUtils { case BINARY => randomBytes(Random.nextInt(32)) case CALENDAR_INTERVAL => new CalendarInterval(Random.nextInt(), Random.nextInt(), Random.nextLong()) + case _: TIMESTAMP_NANOS => + // nanosWithinMicro must be in [0, 999]; epochMicros can be any long. + TimestampNanosVal.fromParts( + Random.nextLong(), Random.nextInt(TimestampNanosVal.MAX_NANOS_WITHIN_MICRO + 1).toShort) case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale) case STRUCT(_) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 57da12e87979a..fcf7edfcf87b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -222,6 +222,36 @@ class InMemoryColumnarQuerySuite extends SharedSparkSession with AdaptiveSparkPl } } + test("cache nanosecond-precision timestamp types") { + // Nanosecond timestamps are non-primitive for the default cache (DefaultCachedBatchSerializer + // .supportsColumnarOutput is true only for primitive types), so they always read back through + // the row path -- the vectorized reader is not exercised, the same as for CalendarInterval, + // Variant, and Decimal. + withSQLConf(SQLConf.TIMESTAMP_NANOS_TYPES_ENABLED.key -> "true") { + Seq("TIMESTAMP_NTZ(9)", "TIMESTAMP_LTZ(9)").foreach { typeName => + withTempView("nanos") { + // Include sub-microsecond precision and a null to exercise the full payload and null + // handling through the cache. + val df = sql( + s"""SELECT * FROM VALUES + | (cast('2020-01-01 00:00:00.123456789' as $typeName)), + | (cast('1999-12-31 23:59:59.987654321' as $typeName)), + | (cast(null as $typeName)) + | as t(ts)""".stripMargin) + df.createOrReplaceTempView("nanos") + val expected = sql("SELECT ts FROM nanos").collect().toSeq + + spark.catalog.cacheTable("nanos") + try { + checkAnswer(sql("SELECT ts FROM nanos"), expected) + } finally { + spark.catalog.uncacheTable("nanos") + } + } + } + } + } + test("SPARK-3320 regression: batched column buffer building should work with empty partitions") { checkAnswer( sql("SELECT * FROM withEmptyParts"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 88ff51d0ff4cf..c8ba722fd74c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -212,4 +212,62 @@ class PartitionBatchPruningSuite extends SharedSparkSession with AdaptiveSparkPl s"Wrong number of read partitions: $queryExecution") } } + + test("SPARK-57735: partition pruning on cached nanosecond-timestamp column") { + withSQLConf(SQLConf.TIMESTAMP_NANOS_TYPES_ENABLED.key -> "true") { + withTempView("nanosPruning") { + // 100 monotonically increasing nanosecond timestamps in one ordered partition; with batch + // size 10 (set in beforeAll) this is 10 batches whose min/max bounds are ordered and + // non-overlapping. makeRDD with a single slice preserves element order (no shuffle), so a + // range filter can prune to just the matching batches. The timestamps differ only in the + // sub-microsecond component, so this also exercises nanosecond-precision bounds. + val rows = (1 to 100).map { k => + Tuple1(s"2020-01-01 00:00:00.${"%09d".format(k)}") + } + // Boundary chosen so only the last batch (values 91..100) qualifies. + val boundary = "cast('2020-01-01 00:00:00.000000090' as TIMESTAMP_NTZ(9))" + + // Compute the expected result BEFORE caching, so it cannot hit the cache (the CacheManager + // matches by logical plan, not by DataFrame identity, so evaluating an equivalent query + // after caching could be served from the InMemoryRelation). + val expected = sparkContext.makeRDD(rows, 1).toDF("s") + .selectExpr("cast(s as TIMESTAMP_NTZ(9)) as ts") + .where(s"ts > $boundary").orderBy("ts").collect().toSeq + assert(expected.nonEmpty && expected.size < 100, + "test boundary should select a strict, non-empty subset") + + sparkContext.makeRDD(rows, 1).toDF("s") + .selectExpr("cast(s as TIMESTAMP_NTZ(9)) as ts") + .createOrReplaceTempView("nanosPruning") + spark.catalog.cacheTable("nanosPruning") + try { + // Correctness: the cached + pruned read matches the pre-cache evaluation. + val cached = sql(s"SELECT ts FROM nanosPruning WHERE ts > $boundary ORDER BY ts") + assert(cached.collect().toSeq === expected, + "cached + pruned result must match the uncached evaluation") + + // Pruning: the same range query reads fewer batches with in-memory partition pruning on + // than off. (With bounds-less stats the counts would be equal because no batch can be + // skipped.) Comparing pruning-on vs pruning-off for the identical query avoids depending + // on the absolute batch/partition count, and mirrors the suite's + // "disable IN_MEMORY_PARTITION_PRUNING" test. + def readBatchesWithPruning(enabled: Boolean): Long = { + withSQLConf(SQLConf.IN_MEMORY_PARTITION_PRUNING.key -> enabled.toString) { + val df = sql(s"SELECT ts FROM nanosPruning WHERE ts > $boundary") + df.collect() + collect(df.queryExecution.executedPlan) { + case in: InMemoryTableScanExec => in.readBatches.value + }.head + } + } + val withoutPruning = readBatchesWithPruning(enabled = false) + val withPruning = readBatchesWithPruning(enabled = true) + assert(withPruning < withoutPruning, + s"expected pruning to read fewer batches: $withPruning (on) vs $withoutPruning (off)") + } finally { + spark.catalog.uncacheTable("nanosPruning") + } + } + } + } }