Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor
import org.apache.spark.sql.execution.vectorized.WritableColumnVector
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
import org.apache.spark.unsafe.types.{CalendarInterval, TimestampNanosVal, VariantVal}

/**
* An `Iterator` like trait used to extract values from columnar byte buffer. When a value is
Expand Down Expand Up @@ -115,6 +115,14 @@ private[columnar] class VariantColumnAccessor(buffer: ByteBuffer)
extends BasicColumnAccessor[VariantVal](buffer, VARIANT)
with NullableColumnAccessor

private[columnar] class TimestampNTZNanosColumnAccessor(buffer: ByteBuffer)
extends BasicColumnAccessor[TimestampNanosVal](buffer, TIMESTAMP_NTZ_NANOS)
with NullableColumnAccessor

private[columnar] class TimestampLTZNanosColumnAccessor(buffer: ByteBuffer)
extends BasicColumnAccessor[TimestampNanosVal](buffer, TIMESTAMP_LTZ_NANOS)
with NullableColumnAccessor

private[columnar] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType)
extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType))

Expand Down Expand Up @@ -153,6 +161,8 @@ private[sql] object ColumnAccessor {
case DoubleType => new DoubleColumnAccessor(buf)
case s: StringType => new StringColumnAccessor(buf, s)
case BinaryType => new BinaryColumnAccessor(buf)
case _: TimestampNTZNanosType => new TimestampNTZNanosColumnAccessor(buf)
case _: TimestampLTZNanosType => new TimestampLTZNanosColumnAccessor(buf)
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
new CompactDecimalColumnAccessor(buf, dt)
case dt: DecimalType => new DecimalColumnAccessor(buf, dt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ class IntervalColumnBuilder extends ComplexColumnBuilder(new IntervalColumnStats
private[columnar]
class VariantColumnBuilder extends ComplexColumnBuilder(new VariantColumnStats, VARIANT)

private[columnar] class TimestampNTZNanosColumnBuilder
extends ComplexColumnBuilder(new TimestampNanosColumnStats, TIMESTAMP_NTZ_NANOS)

private[columnar] class TimestampLTZNanosColumnBuilder
extends ComplexColumnBuilder(new TimestampNanosColumnStats, TIMESTAMP_LTZ_NANOS)

private[columnar] class CompactDecimalColumnBuilder(dataType: DecimalType)
extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType))

Expand Down Expand Up @@ -193,6 +199,8 @@ private[columnar] object ColumnBuilder {
case BinaryType => new BinaryColumnBuilder
case CalendarIntervalType => new IntervalColumnBuilder
case VariantType => new VariantColumnBuilder
case _: TimestampNTZNanosType => new TimestampNTZNanosColumnBuilder
case _: TimestampLTZNanosType => new TimestampLTZNanosColumnBuilder
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
new CompactDecimalColumnBuilder(dt)
case dt: DecimalType => new DecimalColumnBuilder(dt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.columnar
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{TimestampNanosVal, UTF8String}

class ColumnStatisticsSchema(a: Attribute) extends Serializable {
val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)()
Expand Down Expand Up @@ -326,6 +326,30 @@ private[columnar] final class IntervalColumnStats extends ColumnStats {
Array[Any](null, null, nullCount, count, sizeInBytes)
}

private[columnar] final class TimestampNanosColumnStats extends ColumnStats {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TimestampNanosColumnStats emits null/null for lower/upper (the CalendarInterval / IntervalColumnStats pattern), so cached nanosecond-timestamp columns get no batch-level partition pruning.

The same logical type at micro precision takes a different path: TimestampType/TimestampNTZType -> LongColumnBuilder -> LongColumnStats, which collects min/max. So a range filter (WHERE ts > '...') over a cached TIMESTAMP_NTZ(6) column skips non-matching batches, while the same filter over a cached TIMESTAMP_NTZ(9) column scans every batch.

TimestampNanosVal is Comparable (its total order is calendar order), and ordered non-primitive cache types already keep bounds — DecimalColumnStats collects Decimal min/max. So tracking upper/lower as TimestampNanosVal here (modeled on DecimalColumnStats rather than IntervalColumnStats) would preserve the pruning the micro path provides.

Not a correctness issue — the feature works. Is the bounds-less choice intentional (follow CalendarInterval), or worth collecting min/max so cached nanos timestamps prune like micro timestamps?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point -- collecting min/max is the right call, thanks. You're right that the bounds-less version was a regression from the micro path: TIMESTAMP_NTZ(6) prunes via LongColumnStats while TIMESTAMP_NTZ(9) scanned every batch.

Following your suggestion, TimestampNanosColumnStats now collects upper/lower as TimestampNanosVal (modeled on DecimalColumnStats rather than IntervalColumnStats), using its compareTo (which is calendar order). The pruning path is already wired for it -- TimestampNTZNanosType is an AtomicType so ExtractableLiteral extracts the literal, and PhysicalTimestampNTZNanosType defines an ordering, so the bound comparisons buildFilter generates are valid -- so cached nanos timestamps now prune like micro timestamps.

Added coverage: ColumnStatsSuite asserts the min/max bounds for both NTZ and LTZ, and PartitionBatchPruningSuite verifies a range filter over a cached nanos column reads fewer batches with in-memory partition pruning on than off (and returns the same rows as a pre-cache evaluation).

protected var upper: TimestampNanosVal = null
protected var lower: TimestampNanosVal = null

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
// TimestampNanosVal has a total order matching calendar order, so collect min/max bounds
// (like DecimalColumnStats, not IntervalColumnStats) to enable partition pruning, matching
// the micro-precision timestamp path (LongColumnStats). NTZ and LTZ share the same physical
// payload, so a single getter reads the value for both.
val value = row.getTimestampNTZNanos(ordinal)
if (upper == null || value.compareTo(upper) > 0) upper = value
if (lower == null || value.compareTo(lower) < 0) lower = value
sizeInBytes += TimestampNanosVal.SIZE_IN_BYTES
count += 1
} else {
gatherNullStats()
}
}

override def collectedStatistics: Array[Any] =
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}

private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
def this(dt: DecimalType) = this(dt.precision, dt.scale)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
import org.apache.spark.unsafe.types.{CalendarInterval, TimestampNanosVal, UTF8String, VariantVal}


/**
Expand Down Expand Up @@ -815,6 +815,73 @@ private[columnar] object CALENDAR_INTERVAL extends ColumnType[CalendarInterval]
}
}

/**
* Used to append/extract [[TimestampNanosVal]] into/from the underlying [[ByteBuffer]] of a
* column. The on-buffer layout mirrors the UnsafeRow variable-length payload: an 8-byte
* epochMicros followed by an 8-byte word holding nanosWithinMicro (zero-extended), 16 bytes total
* (see TimestampNanosRowValues). NTZ and LTZ share this storage and differ only by physical type,
* so the two singletons below pass their own physicalType and row getter/setter.
*/
private[columnar] abstract class TIMESTAMP_NANOS(physicalType: PhysicalDataType)
extends ColumnType[TimestampNanosVal] {

override def dataType: PhysicalDataType = physicalType

override def defaultSize: Int = TimestampNanosVal.SIZE_IN_BYTES

protected def getNanos(row: InternalRow, ordinal: Int): TimestampNanosVal
protected def setNanos(row: InternalRow, ordinal: Int, value: TimestampNanosVal): Unit

override def getField(row: InternalRow, ordinal: Int): TimestampNanosVal = getNanos(row, ordinal)

override def setField(row: InternalRow, ordinal: Int, value: TimestampNanosVal): Unit =
setNanos(row, ordinal, value)

override def extract(buffer: ByteBuffer): TimestampNanosVal = {
val epochMicros = ByteBufferHelper.getLong(buffer)
// The nanos field is stored in a full 8-byte word (matching the UnsafeRow payload), so read a
// long and narrow it; the writer guarantees the value is in [0, 999].
val nanosWithinMicro = ByteBufferHelper.getLong(buffer).toShort
TimestampNanosVal.fromTrustedRowBytes(epochMicros, nanosWithinMicro)
}

// Copy the fixed 16-byte payload straight into the UnsafeRow, like CALENDAR_INTERVAL.
override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
row match {
case mutable: MutableUnsafeRow =>
val cursor = buffer.position()
buffer.position(cursor + defaultSize)
mutable.writer.write(ordinal, buffer.array(),
buffer.arrayOffset() + cursor, defaultSize)
case _ =>
setField(row, ordinal, extract(buffer))
}
}

override def append(v: TimestampNanosVal, buffer: ByteBuffer): Unit = {
ByteBufferHelper.putLong(buffer, v.epochMicros)
ByteBufferHelper.putLong(buffer, v.nanosWithinMicro.toLong)
}
}

private[columnar] object TIMESTAMP_NTZ_NANOS
extends TIMESTAMP_NANOS(PhysicalTimestampNTZNanosType) {
override protected def getNanos(row: InternalRow, ordinal: Int): TimestampNanosVal =
row.getTimestampNTZNanos(ordinal)
override protected def setNanos(
row: InternalRow, ordinal: Int, value: TimestampNanosVal): Unit =
row.setTimestampNTZNanos(ordinal, value)
}

private[columnar] object TIMESTAMP_LTZ_NANOS
extends TIMESTAMP_NANOS(PhysicalTimestampLTZNanosType) {
override protected def getNanos(row: InternalRow, ordinal: Int): TimestampNanosVal =
row.getTimestampLTZNanos(ordinal)
override protected def setNanos(
row: InternalRow, ordinal: Int, value: TimestampNanosVal): Unit =
row.setTimestampLTZNanos(ordinal, value)
}

/**
* Used to append/extract Java VariantVals into/from the underlying [[ByteBuffer]] of a column.
*
Expand Down Expand Up @@ -876,6 +943,8 @@ private[columnar] object ColumnType {
case s: StringType => STRING(s)
case BinaryType => BINARY
case i: CalendarIntervalType => CALENDAR_INTERVAL
case _: TimestampNTZNanosType => TIMESTAMP_NTZ_NANOS
case _: TimestampLTZNanosType => TIMESTAMP_LTZ_NANOS
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => COMPACT_DECIMAL(dt)
case dt: DecimalType => LARGE_DECIMAL(dt)
case arr: ArrayType => ARRAY(PhysicalArrayType(arr.elementType, arr.containsNull))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
case BinaryType => classOf[BinaryColumnAccessor].getName
case CalendarIntervalType => classOf[IntervalColumnAccessor].getName
case VariantType => classOf[VariantColumnAccessor].getName
case _: TimestampNTZNanosType => classOf[TimestampNTZNanosColumnAccessor].getName
case _: TimestampLTZNanosType => classOf[TimestampLTZNanosColumnAccessor].getName
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
classOf[CompactDecimalColumnAccessor].getName
case dt: DecimalType => classOf[DecimalColumnAccessor].getName
Expand All @@ -102,7 +104,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
val createCode = dt match {
case t if CodeGenerator.isPrimitiveType(dt) =>
s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
case NullType | BinaryType | CalendarIntervalType | VariantType =>
case NullType | BinaryType | CalendarIntervalType | VariantType |
_: TimestampNTZNanosType | _: TimestampLTZNanosType =>
s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
case other =>
s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.columnar
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.TimestampNanosVal

class ColumnStatsSuite extends SparkFunSuite {
testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0))
Expand All @@ -32,6 +33,8 @@ class ColumnStatsSuite extends SparkFunSuite {
testDecimalColumnStats(Array(null, null, 0))
testIntervalColumnStats(Array(null, null, 0))
testStringColumnStats(Array(null, null, 0))
testTimestampNanosColumnStats(TIMESTAMP_NTZ_NANOS, Array(null, null, 0))
testTimestampNanosColumnStats(TIMESTAMP_LTZ_NANOS, Array(null, null, 0))

def testColumnStats[T <: PhysicalDataType, U <: ColumnStats](
columnStatsClass: Class[U],
Expand Down Expand Up @@ -143,6 +146,40 @@ class ColumnStatsSuite extends SparkFunSuite {
}
}

def testTimestampNanosColumnStats(
columnType: ColumnType[TimestampNanosVal],
initialStatistics: Array[Any]): Unit = {

val columnStatsName = classOf[TimestampNanosColumnStats].getSimpleName

test(s"$columnStatsName ($columnType): empty") {
val columnStats = new TimestampNanosColumnStats
columnStats.collectedStatistics.zip(initialStatistics).foreach {
case (actual, expected) => assert(actual === expected)
}
}

test(s"$columnStatsName ($columnType): non-empty collects min/max bounds") {
import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._

val columnStats = new TimestampNanosColumnStats
val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
rows.foreach(columnStats.gatherStats(_, 0))

val values = rows.take(10).map(_.get(0,
ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType))
.asInstanceOf[TimestampNanosVal])
val ordering = Ordering.fromLessThan[TimestampNanosVal](_.compareTo(_) < 0)
val stats = columnStats.collectedStatistics

assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
assertResult(10, "Wrong null count")(stats(2))
assertResult(20, "Wrong row count")(stats(3))
assertResult(TimestampNanosVal.SIZE_IN_BYTES * 10 + 4 * 10, "Wrong size in bytes")(stats(4))
}
}

def testStringColumnStats[T <: PhysicalDataType, U <: ColumnStats](
initialStatistics: Array[Any]): Unit = {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class ColumnTypeSuite extends SparkFunSuite {
STRING(StringType) -> 8, STRING(StringType("UTF8_LCASE")) -> 8,
STRING(StringType("UNICODE")) -> 8, STRING(StringType("UNICODE_CI")) -> 8,
BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE -> 68,
CALENDAR_INTERVAL -> 16)
CALENDAR_INTERVAL -> 16, TIMESTAMP_NTZ_NANOS -> 16, TIMESTAMP_LTZ_NANOS -> 16)

checks.foreach { case (columnType, expectedSize) =>
assertResult(expectedSize, s"Wrong defaultSize for $columnType") {
Expand Down Expand Up @@ -113,6 +113,8 @@ class ColumnTypeSuite extends SparkFunSuite {
testColumnType(ARRAY_TYPE)
testColumnType(MAP_TYPE)
testColumnType(CALENDAR_INTERVAL)
testColumnType(TIMESTAMP_NTZ_NANOS)
testColumnType(TIMESTAMP_LTZ_NANOS)

def testNativeColumnType[T <: PhysicalDataType](columnType: NativeColumnType[T]): Unit = {
val typeName = columnType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ object ColumnarDataTypeUtils {
case PhysicalShortType => ShortType
case PhysicalBinaryType => BinaryType
case PhysicalCalendarIntervalType => CalendarIntervalType
case PhysicalTimestampNTZNanosType => TimestampNTZNanosType()
case PhysicalTimestampLTZNanosType => TimestampLTZNanosType()
case PhysicalFloatType => FloatType
case PhysicalDoubleType => DoubleType
case PhysicalStringType(collationId) => StringType(collationId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.unsafe.types.{CalendarInterval, TimestampNanosVal, UTF8String}

object ColumnarTestUtils {
def makeNullRow(length: Int): GenericInternalRow = {
Expand Down Expand Up @@ -54,6 +54,10 @@ object ColumnarTestUtils {
case BINARY => randomBytes(Random.nextInt(32))
case CALENDAR_INTERVAL =>
new CalendarInterval(Random.nextInt(), Random.nextInt(), Random.nextLong())
case _: TIMESTAMP_NANOS =>
// nanosWithinMicro must be in [0, 999]; epochMicros can be any long.
TimestampNanosVal.fromParts(
Random.nextLong(), Random.nextInt(TimestampNanosVal.MAX_NANOS_WITHIN_MICRO + 1).toShort)
case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale)
case STRUCT(_) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,36 @@ class InMemoryColumnarQuerySuite extends SharedSparkSession with AdaptiveSparkPl
}
}

test("cache nanosecond-precision timestamp types") {
// Nanosecond timestamps are non-primitive for the default cache (DefaultCachedBatchSerializer
// .supportsColumnarOutput is true only for primitive types), so they always read back through
// the row path -- the vectorized reader is not exercised, the same as for CalendarInterval,
// Variant, and Decimal.
withSQLConf(SQLConf.TIMESTAMP_NANOS_TYPES_ENABLED.key -> "true") {
Seq("TIMESTAMP_NTZ(9)", "TIMESTAMP_LTZ(9)").foreach { typeName =>
withTempView("nanos") {
// Include sub-microsecond precision and a null to exercise the full payload and null
// handling through the cache.
val df = sql(
s"""SELECT * FROM VALUES
| (cast('2020-01-01 00:00:00.123456789' as $typeName)),
| (cast('1999-12-31 23:59:59.987654321' as $typeName)),
| (cast(null as $typeName))
| as t(ts)""".stripMargin)
df.createOrReplaceTempView("nanos")
val expected = sql("SELECT ts FROM nanos").collect().toSeq

spark.catalog.cacheTable("nanos")
try {
checkAnswer(sql("SELECT ts FROM nanos"), expected)
} finally {
spark.catalog.uncacheTable("nanos")
}
}
}
}
}

test("SPARK-3320 regression: batched column buffer building should work with empty partitions") {
checkAnswer(
sql("SELECT * FROM withEmptyParts"),
Expand Down
Loading