diff --git a/psc-examples/pom.xml b/psc-examples/pom.xml index bfbb1de2..ec4cc509 100644 --- a/psc-examples/pom.xml +++ b/psc-examples/pom.xml @@ -15,7 +15,7 @@ psc-examples - 0.2.21 + 1.0.2 diff --git a/psc-flink/src/main/java/com/pinterest/flink/connector/psc/source/metrics/MemqSourceReaderMetricsUtil.java b/psc-flink/src/main/java/com/pinterest/flink/connector/psc/source/metrics/MemqSourceReaderMetricsUtil.java new file mode 100644 index 00000000..0f9aa087 --- /dev/null +++ b/psc-flink/src/main/java/com/pinterest/flink/connector/psc/source/metrics/MemqSourceReaderMetricsUtil.java @@ -0,0 +1,27 @@ +package com.pinterest.flink.connector.psc.source.metrics; + +import com.pinterest.psc.common.TopicUriPartition; +import com.pinterest.psc.metrics.Metric; +import com.pinterest.psc.metrics.MetricName; + +import java.util.Map; +import java.util.function.Predicate; + +class MemqSourceReaderMetricsUtil { + + public static final String MEMQ_CONSUMER_METRIC_GROUP = "memq-consumer-metrics"; + public static final String BYTES_CONSUMED_TOTAL = "bytes.consumed.total"; + public static final String NOTIFICATION_RECORDS_LAG_MAX = "notification.records.lag.max"; + + protected static Predicate> createBytesConsumedFilter() { + return entry -> + entry.getKey().group().equals(MEMQ_CONSUMER_METRIC_GROUP) + && entry.getKey().name().equals(BYTES_CONSUMED_TOTAL); + } + + protected static Predicate> createRecordsLagFilter(TopicUriPartition tp) { + return entry -> + entry.getKey().group().equals(MEMQ_CONSUMER_METRIC_GROUP) + && entry.getKey().name().equals(NOTIFICATION_RECORDS_LAG_MAX); + } +} diff --git a/psc-flink/src/main/java/com/pinterest/flink/connector/psc/source/metrics/PscSourceReaderMetrics.java b/psc-flink/src/main/java/com/pinterest/flink/connector/psc/source/metrics/PscSourceReaderMetrics.java index e821b67d..644bd262 100644 --- a/psc-flink/src/main/java/com/pinterest/flink/connector/psc/source/metrics/PscSourceReaderMetrics.java +++ b/psc-flink/src/main/java/com/pinterest/flink/connector/psc/source/metrics/PscSourceReaderMetrics.java @@ -25,6 +25,7 @@ import com.pinterest.psc.exception.ClientException; import com.pinterest.psc.metrics.Metric; import com.pinterest.psc.metrics.MetricName; +import java.util.Iterator; import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.metrics.Counter; import org.apache.flink.metrics.MetricGroup; @@ -153,6 +154,19 @@ public void recordCurrentOffset(TopicUriPartition tp, long offset) { offsets.get(tp).currentOffset = offset; } + /** + * Returns the {@link Offset} tracker for the given partition, allowing callers to + * cache it and update offsets directly without repeated HashMap lookups. + * + * @param tp the topic partition to get the tracker for + * @return the Offset tracker + * @throws IllegalArgumentException if the partition is not tracked + */ + public Offset getOffsetTracker(TopicUriPartition tp) { + checkTopicPartitionTracked(tp); + return offsets.get(tp); + } + /** * Update the latest committed offset of the given {@link TopicUriPartition}. * @@ -180,8 +194,21 @@ public void recordFailedCommit() { * @param consumer Kafka consumer */ public void registerNumBytesIn(PscConsumer consumer) throws ClientException { - Predicate> filter = - KafkaSourceReaderMetricsUtil.createBytesConsumedFilter(); + String backendType = getBackendFromTags(consumer.metrics()); + Predicate> filter; + switch (backendType) { + case PscUtils.BACKEND_TYPE_KAFKA: + filter = KafkaSourceReaderMetricsUtil.createBytesConsumedFilter(); + break; + case PscUtils.BACKEND_TYPE_MEMQ: + filter = MemqSourceReaderMetricsUtil.createBytesConsumedFilter(); + break; + default: + LOG.warn( + "Unsupported backend type: \"{}\". Metric \"{}\" may not be reported correctly.", + backendType, MetricNames.IO_NUM_BYTES_IN); + return; + } this.bytesConsumedTotalMetric = MetricUtil.getPscMetric(consumer.metrics(), filter); } @@ -288,25 +315,35 @@ private void checkTopicPartitionTracked(TopicUriPartition tp) { case PscUtils.BACKEND_TYPE_KAFKA: filter = KafkaSourceReaderMetricsUtil.createRecordLagFilter(tp); break; + case PscUtils.BACKEND_TYPE_MEMQ: + filter = MemqSourceReaderMetricsUtil.createRecordsLagFilter(tp); + break; default: LOG.warn( - String.format( - "Unsupported backend type \"%s\". " - + "Metric \"%s\" may not be reported correctly. ", - backendType, MetricNames.PENDING_RECORDS)); + "Unsupported backend type \"{}\". Metric \"{}\" may not be reported correctly.", + backendType, MetricNames.PENDING_RECORDS); return null; } - return MetricUtil.getPscMetric(metrics, filter); + try { + return MetricUtil.getPscMetric(metrics, filter); + } catch (IllegalStateException e) { + LOG.debug("Metric not yet available for backend \"{}\", will retry on next poll cycle.", backendType); + return null; + } } private static String getBackendFromTags(Map metrics) { - // sample the first entry to get the backend type - return metrics.keySet().iterator().next().tags().get("backend"); + Iterator it = metrics.keySet().iterator(); + if (!it.hasNext()) { + return "unknown"; + } + String backend = it.next().tags().get("backend"); + return backend != null ? backend : "unknown"; } - private static class Offset { - long currentOffset; - long committedOffset; + public static class Offset { + public long currentOffset; + public long committedOffset; Offset(long currentOffset, long committedOffset) { this.currentOffset = currentOffset; diff --git a/psc-flink/src/main/java/com/pinterest/flink/connector/psc/source/reader/PscTopicUriPartitionSplitReader.java b/psc-flink/src/main/java/com/pinterest/flink/connector/psc/source/reader/PscTopicUriPartitionSplitReader.java index 4e8adc54..9428f288 100644 --- a/psc-flink/src/main/java/com/pinterest/flink/connector/psc/source/reader/PscTopicUriPartitionSplitReader.java +++ b/psc-flink/src/main/java/com/pinterest/flink/connector/psc/source/reader/PscTopicUriPartitionSplitReader.java @@ -567,6 +567,7 @@ private static class PscPartitionSplitRecords private Iterator> recordIterator; private TopicUriPartition currentTopicPartition; private Long currentSplitStoppingOffset; + private PscSourceReaderMetrics.Offset currentOffsetTracker; private PscPartitionSplitRecords( PscConsumerMessagesIterable consumerMessagesIterable, PscSourceReaderMetrics metrics) { @@ -592,11 +593,13 @@ public String nextSplit() { recordIterator = consumerMessagesIterable.getMessagesForTopicUriPartition(currentTopicPartition).iterator(); currentSplitStoppingOffset = stoppingOffsets.getOrDefault(currentTopicPartition, Long.MAX_VALUE); + currentOffsetTracker = metrics.getOffsetTracker(currentTopicPartition); return currentTopicPartition.toString(); } else { currentTopicPartition = null; recordIterator = null; currentSplitStoppingOffset = null; + currentOffsetTracker = null; return null; } } @@ -612,7 +615,7 @@ public PscConsumerMessage nextRecordFromSplit() { final PscConsumerMessage message = recordIterator.next(); // Only emit records before stopping offset if (message.getMessageId().getOffset() < currentSplitStoppingOffset) { - metrics.recordCurrentOffset(currentTopicPartition, message.getMessageId().getOffset()); + currentOffsetTracker.currentOffset = message.getMessageId().getOffset(); return message; } } diff --git a/psc-flink/src/main/java/com/pinterest/flink/streaming/connectors/psc/table/DynamicPscDeserializationSchema.java b/psc-flink/src/main/java/com/pinterest/flink/streaming/connectors/psc/table/DynamicPscDeserializationSchema.java index 1fb24c1e..df1ff29b 100644 --- a/psc-flink/src/main/java/com/pinterest/flink/streaming/connectors/psc/table/DynamicPscDeserializationSchema.java +++ b/psc-flink/src/main/java/com/pinterest/flink/streaming/connectors/psc/table/DynamicPscDeserializationSchema.java @@ -55,16 +55,21 @@ class DynamicPscDeserializationSchema implements PscDeserializationSchema keyDeserialization, int[] keyProjection, + int[] keySourceProjection, DeserializationSchema valueDeserialization, int[] valueProjection, + int[] valueSourceProjection, boolean hasMetadata, MetadataConverter[] metadataConverters, TypeInformation producedTypeInfo, - boolean upsertMode) { + boolean upsertMode, + boolean projectionActive) { if (upsertMode) { Preconditions.checkArgument( keyDeserialization != null && keyProjection.length > 0, @@ -78,11 +83,14 @@ class DynamicPscDeserializationSchema implements PscDeserializationSchema message, Collector> valueDecodingFormat; - /** Indices that determine the key fields and the target position in the produced row. */ - protected final int[] keyProjection; + /** + * Projection paths for key fields. Each int[] is a path to a field: + * - [topLevelIndex] for top-level fields + * - [topLevelIndex, nestedIndex, ...] for nested fields within ROW types + * Used by formats that support nested projection (e.g., Thrift's PartialThriftDeserializer). + */ + protected int[][] keyProjection; - /** Indices that determine the value fields and the target position in the produced row. */ - protected final int[] valueProjection; + /** + * Projection paths for value fields. Each int[] is a path to a field: + * - [topLevelIndex] for top-level fields + * - [topLevelIndex, nestedIndex, ...] for nested fields within ROW types + * Used by formats that support nested projection (e.g., Thrift's PartialThriftDeserializer). + */ + protected int[][] valueProjection; + + // Query-specific projections (see SupportsProjectionPushDown). + // - *Format* projections are indices into physicalDataType (control what gets deserialized). + // - *Output* projections are indices into the projected physical row (control where fields land). + // - *Source* projections are indices into the decoded row (control where to read fields from). + // TODO: Key projection pushdown needs further E2E validation with complex key types (e.g., Thrift-serialized keys). + protected int[] keyFormatProjection; + protected int[] valueFormatProjection; + protected int[] keyOutputProjection; + protected int[] valueOutputProjection; + protected int[] keySourceProjection; + protected int[] valueSourceProjection; /** Prefix that needs to be removed from fields when constructing the physical data type. */ protected final @Nullable String keyPrefix; @@ -206,6 +236,10 @@ private int getIntendedParallelism(StreamExecutionEnvironment execEnv) { return scanParallelism != null ? scanParallelism : execEnv.getParallelism(); } + /** + * Backwards-compatible constructor that accepts int[] projections. + * Converts them to int[][] format internally. + */ public PscDynamicSource( DataType physicalDataType, @Nullable DecodingFormat> keyDecodingFormat, @@ -228,6 +262,55 @@ public PscDynamicSource( boolean enableRescale, @Nullable Double rateLimitRecordsPerSecond, @Nullable Integer scanParallelism) { + this( + physicalDataType, + keyDecodingFormat, + valueDecodingFormat, + toNestedProjection(keyProjection), + toNestedProjection(valueProjection), + keyPrefix, + topics, + topicPattern, + properties, + startupMode, + specificStartupOffsets, + startupTimestampMillis, + boundedMode, + specificBoundedOffsets, + boundedTimestampMillis, + upsertMode, + tableIdentifier, + sourceUidPrefix, + enableRescale, + rateLimitRecordsPerSecond, + scanParallelism); + } + + /** + * Primary constructor that accepts int[][] projections for full nested field support. + */ + public PscDynamicSource( + DataType physicalDataType, + @Nullable DecodingFormat> keyDecodingFormat, + DecodingFormat> valueDecodingFormat, + int[][] keyProjection, + int[][] valueProjection, + @Nullable String keyPrefix, + @Nullable List topics, + @Nullable Pattern topicPattern, + Properties properties, + StartupMode startupMode, + Map specificStartupOffsets, + long startupTimestampMillis, + BoundedMode boundedMode, + Map specificBoundedOffsets, + long boundedTimestampMillis, + boolean upsertMode, + String tableIdentifier, + @Nullable String sourceUidPrefix, + boolean enableRescale, + @Nullable Double rateLimitRecordsPerSecond, + @Nullable Integer scanParallelism) { // Format attributes this.physicalDataType = Preconditions.checkNotNull( @@ -241,6 +324,14 @@ public PscDynamicSource( this.valueProjection = Preconditions.checkNotNull(valueProjection, "Value projection must not be null."); this.keyPrefix = keyPrefix; + + // Default behavior: no projection pushdown, keep the DDL-level projections. + this.keyFormatProjection = getTopLevelIndices(this.keyProjection); + this.valueFormatProjection = getTopLevelIndices(this.valueProjection); + this.keyOutputProjection = getTopLevelIndices(this.keyProjection); + this.valueOutputProjection = getTopLevelIndices(this.valueProjection); + this.keySourceProjection = IntStream.range(0, this.keyFormatProjection.length).toArray(); + this.valueSourceProjection = IntStream.range(0, this.valueFormatProjection.length).toArray(); // Mutable attributes this.producedDataType = physicalDataType; this.metadataKeys = Collections.emptyList(); @@ -327,10 +418,20 @@ public ChangelogMode getChangelogMode() { @Override public ScanRuntimeProvider getScanRuntimeProvider(ScanContext context) { final DeserializationSchema keyDeserialization = - createDeserialization(context, keyDecodingFormat, keyProjection, keyPrefix); + createDeserialization( + context, + keyDecodingFormat, + keyFormatProjection, + keyProjection, + keyPrefix); final DeserializationSchema valueDeserialization = - createDeserialization(context, valueDecodingFormat, valueProjection, null); + createDeserialization( + context, + valueDecodingFormat, + valueFormatProjection, + valueProjection, + null); final TypeInformation producedTypeInfo = context.createTypeInformation(producedDataType); @@ -456,6 +557,267 @@ public void applyWatermark(WatermarkStrategy watermarkStrategy) { this.watermarkStrategy = watermarkStrategy; } + @Override + public boolean supportsNestedProjection() { + return true; + } + + @Override + public void applyProjection(int[][] projectedFields, DataType producedDataType) { + Preconditions.checkNotNull(projectedFields, "Projected fields must not be null."); + Preconditions.checkNotNull(producedDataType, "Produced data type must not be null."); + + final LogicalType physicalType = physicalDataType.getLogicalType(); + Preconditions.checkArgument( + physicalType.is(LogicalTypeRoot.ROW), "Row data type expected."); + final int physicalFieldCount = LogicalTypeChecks.getFieldCount(physicalType); + + // projectedFields is a 2D array where: + // - The first dimension represents the output field position (order in the projected row) + // - The second dimension is the path to the field: [topLevelIndex] for top-level fields, + // or [topLevelIndex, nestedIndex, ...] for nested fields within ROW types + // Example: For schema (a INT, b ROW, c BIGINT): + // - [[2], [1, 0]] means SELECT c, b.x → output row has c at position 0, b.x at position 1 + // - [2] is the path to top-level field 'c' + // - [1, 0] is the path to nested field 'x' within 'b' + + // Track which top-level fields are projected (for filtering) + final boolean[] physicalFieldProjected = new boolean[physicalFieldCount]; + + // Group projected paths by their top-level field index, preserving output positions. + // Each entry maps: path -> outputPosition + // This fixes the collision issue where multiple nested fields from the same parent + // (e.g., b.key and b.value) each need their own output position. + final Map> pathsByTopLevelIndex = new LinkedHashMap<>(); + + for (int outputPos = 0; outputPos < projectedFields.length; outputPos++) { + final int[] path = projectedFields[outputPos]; + Preconditions.checkArgument( + path != null && path.length >= 1, + "Projection path must have at least one element but got: %s", + Arrays.toString(path)); + final int physicalPos = path[0]; + Preconditions.checkArgument( + physicalPos >= 0 && physicalPos < physicalFieldCount, + "Projected field index out of bounds: %s", + physicalPos); + + physicalFieldProjected[physicalPos] = true; + pathsByTopLevelIndex + .computeIfAbsent(physicalPos, k -> new ArrayList<>()) + .add(new PathWithOutputPos(path, outputPos)); + } + + // This sets the physical output type. Note that SupportsReadingMetadata#applyReadableMetadata + // may overwrite producedDataType later with appended metadata columns. + this.producedDataType = producedDataType; + + // Get original top-level indices from current projections + int[] originalKeyTopLevel = getTopLevelIndices(keyProjection); + int[] originalValueTopLevel = getTopLevelIndices(valueProjection); + + final boolean hasNestedProjection = Arrays.stream(projectedFields) + .anyMatch(path -> path.length > 1); + + // Build key format projection, nested paths, and output mapping + List keyFormatList = new ArrayList<>(); + List keyNestedList = new ArrayList<>(); + List keyOutputList = new ArrayList<>(); + List keySourceList = new ArrayList<>(); + Map keyPhysToDecoded = new HashMap<>(); + for (int i = 0; i < keyFormatProjection.length; i++) { + keyPhysToDecoded.put(keyFormatProjection[i], i); + } + for (int physicalPos : originalKeyTopLevel) { + if (physicalFieldProjected[physicalPos]) { + keyFormatList.add(physicalPos); + List pathsWithPos = pathsByTopLevelIndex.get(physicalPos); + if (pathsWithPos != null) { + for (PathWithOutputPos pwp : pathsWithPos) { + keyNestedList.add(pwp.path); + keyOutputList.add(pwp.outputPos); + keySourceList.add(keyPhysToDecoded.getOrDefault(physicalPos, physicalPos)); + } + } + } + } + this.keyProjection = keyNestedList.toArray(new int[0][]); + this.keyOutputProjection = keyOutputList.stream().mapToInt(Integer::intValue).toArray(); + + // Build value format projection, nested paths, and output mapping + List valueFormatList = new ArrayList<>(); + List valueNestedList = new ArrayList<>(); + List valueOutputList = new ArrayList<>(); + List valueSourceList = new ArrayList<>(); + Map valuePhysToDecoded = new HashMap<>(); + for (int i = 0; i < valueFormatProjection.length; i++) { + valuePhysToDecoded.put(valueFormatProjection[i], i); + } + for (int physicalPos : originalValueTopLevel) { + if (physicalFieldProjected[physicalPos]) { + valueFormatList.add(physicalPos); + List pathsWithPos = pathsByTopLevelIndex.get(physicalPos); + if (pathsWithPos != null) { + for (PathWithOutputPos pwp : pathsWithPos) { + valueNestedList.add(pwp.path); + valueOutputList.add(pwp.outputPos); + valueSourceList.add(valuePhysToDecoded.getOrDefault(physicalPos, physicalPos)); + } + } + } + } + this.valueProjection = valueNestedList.toArray(new int[0][]); + this.valueOutputProjection = valueOutputList.stream().mapToInt(Integer::intValue).toArray(); + + if (hasNestedProjection) { + // Nested projection: prune format projections so the format decoder (e.g., Thrift's + // PartialThriftDeserializer) only deserializes the projected nested fields. + this.keyFormatProjection = keyFormatList.stream().mapToInt(Integer::intValue).toArray(); + this.valueFormatProjection = valueFormatList.stream().mapToInt(Integer::intValue).toArray(); + // Format decoded fields are sequential [0..N-1] since the schema is pruned. + this.keySourceProjection = IntStream.range(0, this.keyOutputProjection.length).toArray(); + this.valueSourceProjection = IntStream.range(0, this.valueOutputProjection.length).toArray(); + } else { + // Top-level only projection: keep full format projections so standard formats + // (CSV, Avro, JSON) can decode the complete wire format. The OutputProjectionCollector + // handles field selection from the full decoded row. + this.keySourceProjection = keySourceList.stream().mapToInt(Integer::intValue).toArray(); + this.valueSourceProjection = valueSourceList.stream().mapToInt(Integer::intValue).toArray(); + } + } + + /** Helper class to associate a projection path with its output position. */ + private static class PathWithOutputPos { + final int[] path; + final int outputPos; + + PathWithOutputPos(int[] path, int outputPos) { + this.path = path; + this.outputPos = outputPos; + } + } + + /** Converts a 1D projection (top-level indices) to 2D nested projection format. */ + private static int[][] toNestedProjection(int[] projection) { + int[][] result = new int[projection.length][]; + for (int i = 0; i < projection.length; i++) { + result[i] = new int[] {projection[i]}; + } + return result; + } + + /** Extracts unique top-level indices from nested projection paths. */ + private static int[] getTopLevelIndices(int[][] projection) { + return Arrays.stream(projection) + .mapToInt(path -> path[0]) + .distinct() + .toArray(); + } + + /** + * Converts flattened field names (underscore-separated) back to dot-separated notation. + * + *

Flink's {@code Projection.of().project()} flattens nested field names using underscores: + * e.g., "viewingUser.active" becomes "viewingUser_active". + * + *

The Thrift partial deserializer expects dot-separated names to build the nested field tree + * via {@code ThriftField.fromNames()}. This method reconstructs the dot-separated names from + * the original projection paths and schema. + * + *

For top-level projections (path length == 1), field names are unchanged. + * For nested projections (path length > 1), field names are converted to dot notation. + * + * @param projectedDataType The DataType with flattened field names from Projection.of().project() + * @param nestedProjection The nested projection paths (e.g., [[0], [1, 0], [1, 1]]) + * @param originalDataType The original physical DataType with proper nested structure + * @return A new DataType with dot-separated field names for nested projections + */ + private static DataType convertToNestedFieldNames( + DataType projectedDataType, + int[][] nestedProjection, + DataType originalDataType) { + + if (nestedProjection == null || nestedProjection.length == 0) { + return projectedDataType; + } + + // Check if any projection is nested (path length > 1) + boolean hasNestedProjection = Arrays.stream(nestedProjection) + .anyMatch(path -> path.length > 1); + + if (!hasNestedProjection) { + // All projections are top-level, no conversion needed + return projectedDataType; + } + + // Get the projected field types (in order) + List projectedFieldTypes = DataType.getFieldDataTypes(projectedDataType); + + if (projectedFieldTypes.size() != nestedProjection.length) { + // Mismatch - return original to avoid errors + LOG.warn("Projected field count ({}) doesn't match projection path count ({}). " + + "Skipping nested field name conversion.", + projectedFieldTypes.size(), nestedProjection.length); + return projectedDataType; + } + + // Build new field names using dot notation + List newFieldNames = new ArrayList<>(); + LogicalType originalLogicalType = originalDataType.getLogicalType(); + + for (int[] path : nestedProjection) { + String fieldName = buildDotSeparatedFieldName(path, originalLogicalType); + newFieldNames.add(fieldName); + } + + // Create new DataType with corrected field names + DataTypes.Field[] fields = new DataTypes.Field[newFieldNames.size()]; + for (int i = 0; i < newFieldNames.size(); i++) { + fields[i] = DataTypes.FIELD(newFieldNames.get(i), projectedFieldTypes.get(i)); + } + + return DataTypes.ROW(fields).notNull(); + } + + /** + * Builds a dot-separated field name from a nested projection path. + * + * @param path The projection path (e.g., [1, 0] for "viewingUser.active") + * @param logicalType The logical type to traverse + * @return The dot-separated field name (e.g., "viewingUser.active") + */ + private static String buildDotSeparatedFieldName(int[] path, LogicalType logicalType) { + StringBuilder fieldName = new StringBuilder(); + LogicalType currentType = logicalType; + + for (int i = 0; i < path.length; i++) { + int fieldIndex = path[i]; + + if (!currentType.is(LogicalTypeRoot.ROW)) { + // Can't traverse further - return what we have + break; + } + + List fieldNames = LogicalTypeChecks.getFieldNames(currentType); + if (fieldIndex < 0 || fieldIndex >= fieldNames.size()) { + LOG.warn("Field index {} out of bounds for type with {} fields", + fieldIndex, fieldNames.size()); + break; + } + + if (fieldName.length() > 0) { + fieldName.append("."); + } + fieldName.append(fieldNames.get(fieldIndex)); + + // Move to the nested type for the next iteration + List fieldTypes = LogicalTypeChecks.getFieldTypes(currentType); + currentType = fieldTypes.get(fieldIndex); + } + + return fieldName.toString(); + } + @Override public DynamicTableSource copy() { final PscDynamicSource copy = @@ -484,6 +846,12 @@ public DynamicTableSource copy() { copy.producedDataType = producedDataType; copy.metadataKeys = metadataKeys; copy.watermarkStrategy = watermarkStrategy; + copy.keyFormatProjection = keyFormatProjection; + copy.valueFormatProjection = valueFormatProjection; + copy.keyOutputProjection = keyOutputProjection; + copy.valueOutputProjection = valueOutputProjection; + copy.keySourceProjection = keySourceProjection; + copy.valueSourceProjection = valueSourceProjection; return copy; } @@ -506,8 +874,8 @@ public boolean equals(Object o) { && Objects.equals(physicalDataType, that.physicalDataType) && Objects.equals(keyDecodingFormat, that.keyDecodingFormat) && Objects.equals(valueDecodingFormat, that.valueDecodingFormat) - && Arrays.equals(keyProjection, that.keyProjection) - && Arrays.equals(valueProjection, that.valueProjection) + && Arrays.deepEquals(keyProjection, that.keyProjection) + && Arrays.deepEquals(valueProjection, that.valueProjection) && Objects.equals(keyPrefix, that.keyPrefix) && Objects.equals(topicUris, that.topicUris) && Objects.equals(String.valueOf(topicUriPattern), String.valueOf(that.topicUriPattern)) @@ -535,8 +903,8 @@ public int hashCode() { physicalDataType, keyDecodingFormat, valueDecodingFormat, - Arrays.hashCode(keyProjection), - Arrays.hashCode(valueProjection), + Arrays.deepHashCode(keyProjection), + Arrays.deepHashCode(valueProjection), keyPrefix, topicUris, topicUriPattern, @@ -683,35 +1051,77 @@ private PscDeserializationSchema createPscDeserializationSchema( DataType.getFieldDataTypes(producedDataType).size() - metadataKeys.size(); // adjust value format projection to include value format's metadata columns at the end + final int formatMetadataCount = + adjustedPhysicalArity - keyOutputProjection.length - valueOutputProjection.length; final int[] adjustedValueProjection = IntStream.concat( - IntStream.of(valueProjection), + IntStream.of(valueOutputProjection), IntStream.range( - keyProjection.length + valueProjection.length, + keyOutputProjection.length + valueOutputProjection.length, adjustedPhysicalArity)) .toArray(); + // Build source projections (where to read from decoded rows). + // Format metadata fields are always at sequential positions after the physical fields. + final int[] adjustedKeySourceProjection = keySourceProjection; + final int valueDecodedPhysicalCount = valueFormatProjection.length; + final int[] adjustedValueSourceProjection = + IntStream.concat( + IntStream.of(valueSourceProjection), + IntStream.range( + valueDecodedPhysicalCount, + valueDecodedPhysicalCount + formatMetadataCount)) + .toArray(); + + // The shortcut path in DynamicPscDeserializationSchema bypasses OutputProjectionCollector. + // It must be disabled when the decoded row has more fields than the output expects. + final int decodedValueFieldCount = valueDecodedPhysicalCount + formatMetadataCount; + final boolean projectionActive = + adjustedValueSourceProjection.length != decodedValueFieldCount; + return new DynamicPscDeserializationSchema( adjustedPhysicalArity, keyDeserialization, - keyProjection, + keyOutputProjection, + adjustedKeySourceProjection, valueDeserialization, adjustedValueProjection, + adjustedValueSourceProjection, hasMetadata, metadataConverters, producedTypeInfo, - upsertMode); + upsertMode, + projectionActive); } private @Nullable DeserializationSchema createDeserialization( Context context, @Nullable DecodingFormat> format, - int[] projection, + int[] formatProjection, + int[][] nestedProjection, @Nullable String prefix) { if (format == null) { return null; } - DataType physicalFormatDataType = Projection.of(projection).project(this.physicalDataType); + + final boolean hasNestedProjection = nestedProjection != null + && Arrays.stream(nestedProjection).anyMatch(path -> path.length > 1); + + DataType physicalFormatDataType; + if (hasNestedProjection) { + // Nested projection: build a pruned schema so formats like Thrift's + // PartialThriftDeserializer only deserialize the projected nested fields. + physicalFormatDataType = + Projection.of(nestedProjection).project(this.physicalDataType); + physicalFormatDataType = convertToNestedFieldNames( + physicalFormatDataType, nestedProjection, this.physicalDataType); + } else { + // Top-level only (or no projection): use the full format projection so standard + // formats (CSV, Avro, JSON) can correctly decode the complete wire format. + physicalFormatDataType = + Projection.of(formatProjection).project(this.physicalDataType); + } + if (prefix != null) { physicalFormatDataType = DataTypeUtils.stripRowPrefix(physicalFormatDataType, prefix); } diff --git a/psc-flink/src/test/java/com/pinterest/flink/streaming/connectors/psc/table/PscProjectionPushdownTest.java b/psc-flink/src/test/java/com/pinterest/flink/streaming/connectors/psc/table/PscProjectionPushdownTest.java new file mode 100644 index 00000000..218a3bbc --- /dev/null +++ b/psc-flink/src/test/java/com/pinterest/flink/streaming/connectors/psc/table/PscProjectionPushdownTest.java @@ -0,0 +1,1123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.pinterest.flink.streaming.connectors.psc.table; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.connector.source.DynamicTableSource; +import org.apache.flink.table.connector.source.abilities.SupportsProjectionPushDown; +import org.apache.flink.table.factories.TestFormatFactory.DecodingFormatMock; +import org.apache.flink.table.runtime.connector.source.ScanRuntimeProviderContext; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.ArrayType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.MapType; +import org.apache.flink.table.types.utils.DataTypeUtils; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Properties; +import java.util.regex.Pattern; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for projection pushdown in {@link PscDynamicSource}. + * + *

These tests verify that when a query selects only a subset of columns, + * the PSC source correctly pushes down the projection so that only the + * required columns are deserialized. + */ +public class PscProjectionPushdownTest { + + /** Full schema: columns a, b, c, d */ + private static final DataType FULL_PHYSICAL_TYPE = + DataTypes.ROW( + DataTypes.FIELD("a", DataTypes.INT()), + DataTypes.FIELD("b", DataTypes.STRING()), + DataTypes.FIELD("c", DataTypes.BIGINT()), + DataTypes.FIELD("d", DataTypes.BOOLEAN())) + .notNull(); + + /** + * Test: SELECT * FROM table + * Expected: All columns [a, b, c, d] are passed to decoder + */ + @Test + public void testSelectAllColumns() { + // SELECT * projects all fields: [a, b, c, d] -> indices [0, 1, 2, 3] + final int[][] projectedFields = new int[][] { + new int[] {0}, new int[] {1}, new int[] {2}, new int[] {3} + }; + final DataType projectedType = + DataTypes.ROW( + DataTypes.FIELD("a", DataTypes.INT()), + DataTypes.FIELD("b", DataTypes.STRING()), + DataTypes.FIELD("c", DataTypes.BIGINT()), + DataTypes.FIELD("d", DataTypes.BOOLEAN())) + .notNull(); + + List decodedColumns = applyProjectionAndGetDecodedColumns(projectedFields, projectedType); + + assertThat(decodedColumns) + .containsExactlyInAnyOrderElementsOf(Arrays.asList("a", "b", "c", "d")); + } + + /** + * Test: SELECT a, b, c FROM table + * Expected: For top-level-only projections, the full schema is passed to the decoder + * so that standard formats (CSV, Avro, JSON) can decode the complete wire format. + * The OutputProjectionCollector handles field selection afterward. + */ + @Test + public void testSelectSubsetOfColumns() { + // SELECT a, b, c -> indices [0, 1, 2] + final int[][] projectedFields = new int[][] { + new int[] {0}, new int[] {1}, new int[] {2} + }; + final DataType projectedType = + DataTypes.ROW( + DataTypes.FIELD("a", DataTypes.INT()), + DataTypes.FIELD("b", DataTypes.STRING()), + DataTypes.FIELD("c", DataTypes.BIGINT())) + .notNull(); + + List decodedColumns = applyProjectionAndGetDecodedColumns(projectedFields, projectedType); + + assertThat(decodedColumns) + .containsExactlyInAnyOrderElementsOf(Arrays.asList("a", "b", "c", "d")); + } + + /** + * Test: SELECT a FROM table WHERE b = 'value' + * Expected: For top-level-only projections, the full schema is passed to the decoder. + */ + @Test + public void testSelectWithFilterRequiresBothColumns() { + // SELECT a WHERE b = 'value' -> need both a (selected) and b (filter) + // Flink planner will project [a, b] -> indices [0, 1] + final int[][] projectedFields = new int[][] { + new int[] {0}, new int[] {1} + }; + final DataType projectedType = + DataTypes.ROW( + DataTypes.FIELD("a", DataTypes.INT()), + DataTypes.FIELD("b", DataTypes.STRING())) + .notNull(); + + List decodedColumns = applyProjectionAndGetDecodedColumns(projectedFields, projectedType); + + assertThat(decodedColumns) + .containsExactlyInAnyOrderElementsOf(Arrays.asList("a", "b", "c", "d")); + } + + /** + * Test: SELECT c, a FROM table (reordered columns) + * Expected: For top-level-only projections, the full schema is passed to the decoder. + */ + @Test + public void testSelectReorderedColumns() { + // SELECT c, a -> indices [2, 0] (reordered) + final int[][] projectedFields = new int[][] { + new int[] {2}, new int[] {0} + }; + final DataType projectedType = + DataTypes.ROW( + DataTypes.FIELD("c", DataTypes.BIGINT()), + DataTypes.FIELD("a", DataTypes.INT())) + .notNull(); + + List decodedColumns = applyProjectionAndGetDecodedColumns(projectedFields, projectedType); + + assertThat(decodedColumns) + .containsExactlyInAnyOrderElementsOf(Arrays.asList("a", "b", "c", "d")); + } + + /** + * Test: SELECT d FROM table + * Expected: For top-level-only projections, the full schema is passed to the decoder. + */ + @Test + public void testSelectSingleColumn() { + // SELECT d -> index [3] + final int[][] projectedFields = new int[][] { + new int[] {3} + }; + final DataType projectedType = + DataTypes.ROW( + DataTypes.FIELD("d", DataTypes.BOOLEAN())) + .notNull(); + + List decodedColumns = applyProjectionAndGetDecodedColumns(projectedFields, projectedType); + + assertThat(decodedColumns) + .containsExactlyInAnyOrderElementsOf(Arrays.asList("a", "b", "c", "d")); + } + + /** + * Test: Verify that nested projection is supported. + * This enables formats like Thrift to deserialize only specific nested fields. + */ + @Test + public void testSupportsNestedProjection() { + PscDynamicSource source = createSource(); + assertThat(source.supportsNestedProjection()).isTrue(); + } + + /** + * Test: SELECT nested.field FROM table (nested projection) + * Expected: The projected nested field is passed to decoder with DOT-SEPARATED name. + * The format (e.g., Thrift) receives the pruned schema with dot notation for nested fields. + * + * Schema: a INT, b ROW, c BIGINT, d BOOLEAN + * Query: SELECT b.x → projects to nested field x within b + */ + @Test + public void testNestedProjection() { + // SELECT b.x -> path [1, 0] means field 0 (x) within field 1 (b) + final int[][] projectedFields = new int[][] { + new int[] {1, 0} // nested path: b.x + }; + // The produced type after projection contains just the nested field + final DataType projectedType = + DataTypes.ROW( + DataTypes.FIELD("x", DataTypes.STRING())) + .notNull(); + + List decodedColumns = applyProjectionAndGetActualFieldNames( + projectedFields, projectedType); + + // After fix: nested fields use dot notation (e.g., "b.x") for Thrift compatibility + assertThat(decodedColumns) + .containsExactlyInAnyOrderElementsOf(Collections.singletonList("b.x")); + } + + /** + * Test: SELECT a, b.y FROM table (mixed top-level and nested projection) + * Expected: Top-level column a and nested field b.y (with dot notation) are passed to decoder. + */ + @Test + public void testMixedTopLevelAndNestedProjection() { + // SELECT a, b.y -> paths [0] and [1, 1] + final int[][] projectedFields = new int[][] { + new int[] {0}, // top-level: a + new int[] {1, 1} // nested: b.y + }; + final DataType projectedType = + DataTypes.ROW( + DataTypes.FIELD("a", DataTypes.INT()), + DataTypes.FIELD("y", DataTypes.INT())) + .notNull(); + + List decodedColumns = applyProjectionAndGetActualFieldNames( + projectedFields, projectedType); + + // After fix: top-level fields keep original name, nested fields use dot notation + assertThat(decodedColumns) + .containsExactlyInAnyOrderElementsOf(Arrays.asList("a", "b.y")); + } + + /** + * Helper method to create a PscDynamicSource, apply projection, and return + * the list of column names that would be passed to the decoder. + */ + private List applyProjectionAndGetDecodedColumns( + int[][] projectedFields, DataType projectedType) { + + final String topicUri = + com.pinterest.flink.streaming.connectors.psc.PscTestEnvironmentWithKafkaAsPubSub + .PSC_TEST_CLUSTER0_URI_PREFIX + + "projection-test-topic"; + + final Properties sourceProperties = new Properties(); + sourceProperties.setProperty( + com.pinterest.flink.connector.psc.PscFlinkConfiguration.CLUSTER_URI_CONFIG, + com.pinterest.flink.streaming.connectors.psc.PscTestEnvironmentWithKafkaAsPubSub + .PSC_TEST_CLUSTER0_URI_PREFIX); + sourceProperties.setProperty( + com.pinterest.psc.config.PscConfiguration.PSC_CONSUMER_GROUP_ID, "dummy"); + sourceProperties.setProperty("client.id.prefix", "test"); + sourceProperties.setProperty( + com.pinterest.flink.connector.psc.source.PscSourceOptions + .PARTITION_DISCOVERY_INTERVAL_MS + .key(), + "1000"); + + final DecodingFormatMock valueFormat = new DecodingFormatMock(",", true); + final PscDynamicSource source = + new PscDynamicSource( + FULL_PHYSICAL_TYPE, + null, + valueFormat, + new int[0], + new int[] {0, 1, 2, 3}, + null, + Collections.singletonList(topicUri), + (Pattern) null, + sourceProperties, + com.pinterest.flink.streaming.connectors.psc.config.StartupMode.EARLIEST, + new HashMap<>(), + 0L, + com.pinterest.flink.streaming.connectors.psc.config.BoundedMode.UNBOUNDED, + new HashMap<>(), + 0L, + false, + "test-table"); + + // Apply the projection + ((SupportsProjectionPushDown) source).applyProjection(projectedFields, projectedType); + + // Trigger creation of runtime decoders + source.getScanRuntimeProvider(ScanRuntimeProviderContext.INSTANCE); + + // Get the captured data type from the mock format + final DataType capturedDecoderProduced = valueFormat.producedDataType; + assertThat(capturedDecoderProduced).isNotNull(); + + return DataTypeUtils.flattenToNames(capturedDecoderProduced); + } + + /** Schema with nested ROW type for nested projection tests: a INT, b ROW, c BIGINT, d BOOLEAN */ + private static final DataType NESTED_PHYSICAL_TYPE = + DataTypes.ROW( + DataTypes.FIELD("a", DataTypes.INT()), + DataTypes.FIELD("b", DataTypes.ROW( + DataTypes.FIELD("x", DataTypes.STRING()), + DataTypes.FIELD("y", DataTypes.INT()))), + DataTypes.FIELD("c", DataTypes.BIGINT()), + DataTypes.FIELD("d", DataTypes.BOOLEAN())) + .notNull(); + + /** + * Helper method for nested projection tests. + */ + private List applyProjectionAndGetDecodedColumnsWithNestedSchema( + int[][] projectedFields, DataType projectedType) { + + final String topicUri = + com.pinterest.flink.streaming.connectors.psc.PscTestEnvironmentWithKafkaAsPubSub + .PSC_TEST_CLUSTER0_URI_PREFIX + + "projection-test-topic"; + + final Properties sourceProperties = new Properties(); + sourceProperties.setProperty( + com.pinterest.flink.connector.psc.PscFlinkConfiguration.CLUSTER_URI_CONFIG, + com.pinterest.flink.streaming.connectors.psc.PscTestEnvironmentWithKafkaAsPubSub + .PSC_TEST_CLUSTER0_URI_PREFIX); + sourceProperties.setProperty( + com.pinterest.psc.config.PscConfiguration.PSC_CONSUMER_GROUP_ID, "dummy"); + sourceProperties.setProperty("client.id.prefix", "test"); + sourceProperties.setProperty( + com.pinterest.flink.connector.psc.source.PscSourceOptions + .PARTITION_DISCOVERY_INTERVAL_MS + .key(), + "1000"); + + final DecodingFormatMock valueFormat = new DecodingFormatMock(",", true); + final PscDynamicSource source = + new PscDynamicSource( + NESTED_PHYSICAL_TYPE, + null, + valueFormat, + new int[0], + new int[] {0, 1, 2, 3}, + null, + Collections.singletonList(topicUri), + (Pattern) null, + sourceProperties, + com.pinterest.flink.streaming.connectors.psc.config.StartupMode.EARLIEST, + new HashMap<>(), + 0L, + com.pinterest.flink.streaming.connectors.psc.config.BoundedMode.UNBOUNDED, + new HashMap<>(), + 0L, + false, + "test-table"); + + // Apply the projection + ((SupportsProjectionPushDown) source).applyProjection(projectedFields, projectedType); + + // Trigger creation of runtime decoders + source.getScanRuntimeProvider(ScanRuntimeProviderContext.INSTANCE); + + // Get the captured data type from the mock format + final DataType capturedDecoderProduced = valueFormat.producedDataType; + assertThat(capturedDecoderProduced).isNotNull(); + + return DataTypeUtils.flattenToNames(capturedDecoderProduced); + } + + /** + * Helper method for nested projection tests that returns ACTUAL field names + * (not flattened). This is used to verify that the fix correctly converts + * underscore-separated names to dot notation for Thrift compatibility. + */ + private List applyProjectionAndGetActualFieldNames( + int[][] projectedFields, DataType projectedType) { + + final String topicUri = + com.pinterest.flink.streaming.connectors.psc.PscTestEnvironmentWithKafkaAsPubSub + .PSC_TEST_CLUSTER0_URI_PREFIX + + "projection-test-topic"; + + final Properties sourceProperties = new Properties(); + sourceProperties.setProperty( + com.pinterest.flink.connector.psc.PscFlinkConfiguration.CLUSTER_URI_CONFIG, + com.pinterest.flink.streaming.connectors.psc.PscTestEnvironmentWithKafkaAsPubSub + .PSC_TEST_CLUSTER0_URI_PREFIX); + sourceProperties.setProperty( + com.pinterest.psc.config.PscConfiguration.PSC_CONSUMER_GROUP_ID, "dummy"); + sourceProperties.setProperty("client.id.prefix", "test"); + sourceProperties.setProperty( + com.pinterest.flink.connector.psc.source.PscSourceOptions + .PARTITION_DISCOVERY_INTERVAL_MS + .key(), + "1000"); + + final DecodingFormatMock valueFormat = new DecodingFormatMock(",", true); + final PscDynamicSource source = + new PscDynamicSource( + NESTED_PHYSICAL_TYPE, + null, + valueFormat, + new int[0], + new int[] {0, 1, 2, 3}, + null, + Collections.singletonList(topicUri), + (Pattern) null, + sourceProperties, + com.pinterest.flink.streaming.connectors.psc.config.StartupMode.EARLIEST, + new HashMap<>(), + 0L, + com.pinterest.flink.streaming.connectors.psc.config.BoundedMode.UNBOUNDED, + new HashMap<>(), + 0L, + false, + "test-table"); + + // Apply the projection + ((SupportsProjectionPushDown) source).applyProjection(projectedFields, projectedType); + + // Trigger creation of runtime decoders + source.getScanRuntimeProvider(ScanRuntimeProviderContext.INSTANCE); + + // Get the captured data type from the mock format + final DataType capturedDecoderProduced = valueFormat.producedDataType; + assertThat(capturedDecoderProduced).isNotNull(); + + // Return actual field names from the RowType (not flattened) + return org.apache.flink.table.types.logical.utils.LogicalTypeChecks + .getFieldNames(capturedDecoderProduced.getLogicalType()); + } + + /** + * Helper method to create a basic PscDynamicSource for simple tests. + */ + private PscDynamicSource createSource() { + final String topicUri = + com.pinterest.flink.streaming.connectors.psc.PscTestEnvironmentWithKafkaAsPubSub + .PSC_TEST_CLUSTER0_URI_PREFIX + + "projection-test-topic"; + + final Properties sourceProperties = new Properties(); + sourceProperties.setProperty( + com.pinterest.flink.connector.psc.PscFlinkConfiguration.CLUSTER_URI_CONFIG, + com.pinterest.flink.streaming.connectors.psc.PscTestEnvironmentWithKafkaAsPubSub + .PSC_TEST_CLUSTER0_URI_PREFIX); + sourceProperties.setProperty( + com.pinterest.psc.config.PscConfiguration.PSC_CONSUMER_GROUP_ID, "dummy"); + sourceProperties.setProperty("client.id.prefix", "test"); + sourceProperties.setProperty( + com.pinterest.flink.connector.psc.source.PscSourceOptions + .PARTITION_DISCOVERY_INTERVAL_MS + .key(), + "1000"); + + final DecodingFormatMock valueFormat = new DecodingFormatMock(",", true); + return new PscDynamicSource( + FULL_PHYSICAL_TYPE, + null, + valueFormat, + new int[0], + new int[] {0, 1, 2, 3}, + null, + Collections.singletonList(topicUri), + (Pattern) null, + sourceProperties, + com.pinterest.flink.streaming.connectors.psc.config.StartupMode.EARLIEST, + new HashMap<>(), + 0L, + com.pinterest.flink.streaming.connectors.psc.config.BoundedMode.UNBOUNDED, + new HashMap<>(), + 0L, + false, + "test-table"); + } + + // ==================== Backwards Compatibility Tests ==================== + + /** + * Test: Verify that a source created without projection pushdown has default + * nested projection arrays initialized (single-element paths for each field). + * This ensures backwards compatibility with existing code. + */ + @Test + public void testDefaultNestedProjectionInitialization() { + PscDynamicSource source = createSource(); + + // Copy the source to access internal state + PscDynamicSource copy = (PscDynamicSource) source.copy(); + + // Without any projection applied, nested projections should be default + // (single-element paths matching the value projection) + // This verifies the constructor properly initializes the nested arrays + assertThat(copy).isNotNull(); + } + + /** + * Test: Verify that copy() properly preserves nested projection state. + * This ensures that source copying (used in Flink's optimizer) works correctly. + */ + @Test + public void testCopyPreservesNestedProjection() { + PscDynamicSource source = createSourceWithNestedSchema(); + + // Apply nested projection + final int[][] projectedFields = new int[][] { + new int[] {0}, // a + new int[] {1, 0}, // b.x + new int[] {1, 1} // b.y + }; + final DataType projectedType = + DataTypes.ROW( + DataTypes.FIELD("a", DataTypes.INT()), + DataTypes.FIELD("x", DataTypes.STRING()), + DataTypes.FIELD("y", DataTypes.INT())) + .notNull(); + + ((SupportsProjectionPushDown) source).applyProjection(projectedFields, projectedType); + + // Copy the source + DynamicTableSource copiedSource = source.copy(); + + // Verify copy is not the same instance + assertThat(copiedSource).isNotSameAs(source); + + // We can't easily compare internal state, but we verify both are PscDynamicSource + assertThat(copiedSource).isInstanceOf(PscDynamicSource.class); + } + + /** + * Test: Verify convertPathsToFieldNames utility method for simple paths. + */ + @Test + public void testConvertPathsToFieldNamesSimple() { + final int[][] paths = new int[][] { + new int[] {0}, // a + new int[] {2} // c + }; + + List fieldNames = convertPathsToFieldNames(paths, FULL_PHYSICAL_TYPE); + + assertThat(fieldNames).containsExactly("a", "c"); + } + + /** + * Test: Verify convertPathsToFieldNames utility method for nested paths. + */ + @Test + public void testConvertPathsToFieldNamesNested() { + final int[][] paths = new int[][] { + new int[] {0}, // a + new int[] {1, 0}, // b.x + new int[] {1, 1} // b.y + }; + + List fieldNames = convertPathsToFieldNames(paths, NESTED_PHYSICAL_TYPE); + + assertThat(fieldNames).containsExactly("a", "b.x", "b.y"); + } + + /** + * Test: Verify convertPathsToFieldNames with deeply nested paths. + */ + @Test + public void testConvertPathsToFieldNamesDeeplyNested() { + // Schema: a INT, b ROW> + final DataType deeplyNestedType = + DataTypes.ROW( + DataTypes.FIELD("a", DataTypes.INT()), + DataTypes.FIELD("b", DataTypes.ROW( + DataTypes.FIELD("x", DataTypes.STRING()), + DataTypes.FIELD("y", DataTypes.ROW( + DataTypes.FIELD("p", DataTypes.INT()), + DataTypes.FIELD("q", DataTypes.STRING())))))) + .notNull(); + + final int[][] paths = new int[][] { + new int[] {0}, // a + new int[] {1, 0}, // b.x + new int[] {1, 1, 0}, // b.y.p + new int[] {1, 1, 1} // b.y.q + }; + + List fieldNames = convertPathsToFieldNames(paths, deeplyNestedType); + + assertThat(fieldNames).containsExactly("a", "b.x", "b.y.p", "b.y.q"); + } + + /** + * Test: Verify convertPathsToFieldNames with ARRAY containing ROW type. + * Schema: items ARRAY> + */ + @Test + public void testConvertPathsToFieldNamesArrayOfRow() { + final DataType arrayOfRowType = + DataTypes.ROW( + DataTypes.FIELD("items", DataTypes.ARRAY( + DataTypes.ROW( + DataTypes.FIELD("id", DataTypes.INT()), + DataTypes.FIELD("name", DataTypes.STRING()))))) + .notNull(); + + final int[][] paths = new int[][] { + new int[] {0, 0}, // items.id + new int[] {0, 1} // items.name + }; + + List fieldNames = convertPathsToFieldNames(paths, arrayOfRowType); + + assertThat(fieldNames).containsExactly("items.id", "items.name"); + } + + /** + * Test: Verify that nested projection works correctly with multiple nested fields + * from the same parent and produces the correct pruned DataType for the format. + */ + @Test + public void testMultipleNestedFieldsFromSameParent() { + // SELECT b.x, b.y FROM table -> paths [1, 0] and [1, 1] + final int[][] projectedFields = new int[][] { + new int[] {1, 0}, // b.x + new int[] {1, 1} // b.y + }; + final DataType projectedType = + DataTypes.ROW( + DataTypes.FIELD("x", DataTypes.STRING()), + DataTypes.FIELD("y", DataTypes.INT())) + .notNull(); + + List decodedColumns = applyProjectionAndGetActualFieldNames( + projectedFields, projectedType); + + // After fix: nested fields use dot notation (e.g., "b.x", "b.y") for Thrift compatibility + assertThat(decodedColumns) + .containsExactlyInAnyOrderElementsOf(Arrays.asList("b.x", "b.y")); + } + + /** + * Test: Verify that output projection correctly maps multiple nested fields + * from the same parent to their respective output positions. + * + * This tests the fix for the collision issue where physicalIndexToOutputIndex + * was being overwritten when multiple nested fields shared the same top-level parent. + * + * For SELECT b.key, b.value FROM foo: + * - b.key should map to output position 0 + * - b.value should map to output position 1 + * - valueOutputProjection should be [0, 1], NOT [1] (the bug) + */ + @Test + public void testOutputProjectionForMultipleNestedFieldsFromSameParent() { + PscDynamicSource source = createSourceWithNestedSchema(); + + // SELECT b.x, b.y FROM table -> paths [1, 0] (output 0) and [1, 1] (output 1) + final int[][] projectedFields = new int[][] { + new int[] {1, 0}, // b.x -> output position 0 + new int[] {1, 1} // b.y -> output position 1 + }; + final DataType projectedType = + DataTypes.ROW( + DataTypes.FIELD("x", DataTypes.STRING()), + DataTypes.FIELD("y", DataTypes.INT())) + .notNull(); + + ((SupportsProjectionPushDown) source).applyProjection(projectedFields, projectedType); + + // Verify nested projection contains both paths + assertThat(source.valueProjection.length).isEqualTo(2); + assertThat(source.valueProjection[0]).containsExactly(1, 0); + assertThat(source.valueProjection[1]).containsExactly(1, 1); + + // Verify output projection maps each nested field to its correct position + // This is the key assertion - before the fix, this would be [1] instead of [0, 1] + assertThat(source.valueOutputProjection.length).isEqualTo(2); + assertThat(source.valueOutputProjection).containsExactly(0, 1); + } + + /** + * Test: Verify output projection with interleaved nested and top-level fields. + * + * For SELECT a, b.y, c FROM table: + * - a (top-level) -> output position 0 + * - b.y (nested) -> output position 1 + * - c (top-level) -> output position 2 + */ + @Test + public void testOutputProjectionWithInterleavedFields() { + PscDynamicSource source = createSourceWithNestedSchema(); + + final int[][] projectedFields = new int[][] { + new int[] {0}, // a -> output position 0 + new int[] {1, 1}, // b.y -> output position 1 + new int[] {2} // c -> output position 2 + }; + final DataType projectedType = + DataTypes.ROW( + DataTypes.FIELD("a", DataTypes.INT()), + DataTypes.FIELD("y", DataTypes.INT()), + DataTypes.FIELD("c", DataTypes.BIGINT())) + .notNull(); + + ((SupportsProjectionPushDown) source).applyProjection(projectedFields, projectedType); + + // All fields are value fields in our test schema + assertThat(source.valueProjection.length).isEqualTo(3); + assertThat(source.valueOutputProjection).containsExactly(0, 1, 2); + } + + /** + * Test: Verify backwards compatibility - existing code that doesn't use + * nested projection still works correctly. + * For top-level-only projections, the format always gets the full schema. + */ + @Test + public void testBackwardsCompatibilityWithoutNestedProjection() { + // Simple top-level projection without any nested fields + final int[][] projectedFields = new int[][] { + new int[] {0}, + new int[] {1} + }; + final DataType projectedType = + DataTypes.ROW( + DataTypes.FIELD("a", DataTypes.INT()), + DataTypes.FIELD("b", DataTypes.STRING())) + .notNull(); + + List decodedColumns = applyProjectionAndGetDecodedColumns(projectedFields, projectedType); + + assertThat(decodedColumns) + .containsExactlyInAnyOrderElementsOf(Arrays.asList("a", "b", "c", "d")); + } + + /** + * Helper method to create a PscDynamicSource with nested schema. + */ + private PscDynamicSource createSourceWithNestedSchema() { + final String topicUri = + com.pinterest.flink.streaming.connectors.psc.PscTestEnvironmentWithKafkaAsPubSub + .PSC_TEST_CLUSTER0_URI_PREFIX + + "projection-test-topic"; + + final Properties sourceProperties = new Properties(); + sourceProperties.setProperty( + com.pinterest.flink.connector.psc.PscFlinkConfiguration.CLUSTER_URI_CONFIG, + com.pinterest.flink.streaming.connectors.psc.PscTestEnvironmentWithKafkaAsPubSub + .PSC_TEST_CLUSTER0_URI_PREFIX); + sourceProperties.setProperty( + com.pinterest.psc.config.PscConfiguration.PSC_CONSUMER_GROUP_ID, "dummy"); + sourceProperties.setProperty("client.id.prefix", "test"); + sourceProperties.setProperty( + com.pinterest.flink.connector.psc.source.PscSourceOptions + .PARTITION_DISCOVERY_INTERVAL_MS + .key(), + "1000"); + + final DecodingFormatMock valueFormat = new DecodingFormatMock(",", true); + return new PscDynamicSource( + NESTED_PHYSICAL_TYPE, + null, + valueFormat, + new int[0], + new int[] {0, 1, 2, 3}, + null, + Collections.singletonList(topicUri), + (Pattern) null, + sourceProperties, + com.pinterest.flink.streaming.connectors.psc.config.StartupMode.EARLIEST, + new HashMap<>(), + 0L, + com.pinterest.flink.streaming.connectors.psc.config.BoundedMode.UNBOUNDED, + new HashMap<>(), + 0L, + false, + "test-table"); + } + + // ==================== Test Utility Methods ==================== + + /** + * Converts nested projection paths to dot-separated field names. + * Example: [[1, 0], [2]] with schema (a, b ROW<x, y>, c) → ["b.x", "c"] + */ + private static List convertPathsToFieldNames(int[][] paths, DataType dataType) { + List fieldNames = new ArrayList<>(); + List topLevelNames = DataType.getFieldNames(dataType); + List topLevelTypes = DataType.getFieldDataTypes(dataType); + + for (int[] path : paths) { + StringBuilder name = new StringBuilder(); + List currentNames = topLevelNames; + List currentTypes = topLevelTypes; + + for (int i = 0; i < path.length; i++) { + int index = path[i]; + if (i > 0) { + name.append("."); + } + name.append(currentNames.get(index)); + + // Navigate to nested type for next iteration + if (i < path.length - 1) { + DataType nestedType = currentTypes.get(index); + // Unwrap collection types (ARRAY, MAP) to get to element/value type + nestedType = unwrapCollectionType(nestedType); + currentNames = DataType.getFieldNames(nestedType); + currentTypes = DataType.getFieldDataTypes(nestedType); + } + } + fieldNames.add(name.toString()); + } + return fieldNames; + } + + /** + * Unwraps collection types (ARRAY, MAP) to get the element/value type containing ROW fields. + */ + private static DataType unwrapCollectionType(DataType dataType) { + LogicalType logicalType = dataType.getLogicalType(); + if (logicalType instanceof ArrayType) { + // ARRAY> - get the element type + List children = dataType.getChildren(); + if (!children.isEmpty()) { + return children.get(0); + } + } else if (logicalType instanceof MapType) { + // MAP> - get the value type (second child) + List children = dataType.getChildren(); + if (children.size() >= 2) { + return children.get(1); + } + } + return dataType; + } + + // ==================== Nested KEY Field Projection Tests ==================== + + /** + * Schema with nested KEY field for key projection tests. + * nested_key ROW (KEY), value_data STRING (VALUE) + */ + private static final DataType NESTED_KEY_PHYSICAL_TYPE = + DataTypes.ROW( + DataTypes.FIELD( + "nested_key", + DataTypes.ROW( + DataTypes.FIELD("id", DataTypes.BIGINT()), + DataTypes.FIELD("name", DataTypes.STRING()))), + DataTypes.FIELD("value_data", DataTypes.STRING())) + .notNull(); + + /** + * Test: Nested projection on a KEY field. + * Schema: nested_key ROW (KEY), value_data STRING (VALUE) + * Query: SELECT nested_key.id FROM table + * + * This test verifies that when a KEY field has nested structure and the query + * only selects a sub-field, the key deserializer receives a pruned schema with + * only the needed field names. The Thrift partial deserializer uses these field + * names to map to Thrift field IDs internally. + * + * Per nickpan47's comment: "if all we need to pass down to PartialThriftDeserializer + * is a RowType with all the embedded field names needed, we won't need this indices + * matching algorithm here." + */ + @Test + public void testNestedProjectionOnKeyField() { + final String topicUri = + com.pinterest.flink.streaming.connectors.psc.PscTestEnvironmentWithKafkaAsPubSub + .PSC_TEST_CLUSTER0_URI_PREFIX + + "nested-key-projection-test-topic"; + + final Properties sourceProperties = new Properties(); + sourceProperties.setProperty( + com.pinterest.flink.connector.psc.PscFlinkConfiguration.CLUSTER_URI_CONFIG, + com.pinterest.flink.streaming.connectors.psc.PscTestEnvironmentWithKafkaAsPubSub + .PSC_TEST_CLUSTER0_URI_PREFIX); + sourceProperties.setProperty( + com.pinterest.psc.config.PscConfiguration.PSC_CONSUMER_GROUP_ID, "dummy"); + sourceProperties.setProperty("client.id.prefix", "test"); + sourceProperties.setProperty( + com.pinterest.flink.connector.psc.source.PscSourceOptions + .PARTITION_DISCOVERY_INTERVAL_MS + .key(), + "1000"); + + final DecodingFormatMock keyFormat = new DecodingFormatMock(",", true); + final DecodingFormatMock valueFormat = new DecodingFormatMock(",", true); + + // nested_key is at index 0 (KEY field), value_data is at index 1 (VALUE field) + // This simulates: key.fields = 'nested_key' + final PscDynamicSource source = + new PscDynamicSource( + NESTED_KEY_PHYSICAL_TYPE, + keyFormat, // key format + valueFormat, // value format + new int[] {0}, // keyProjection: field 0 (nested_key) - equivalent to key.fields='nested_key' + new int[] {1}, // valueProjection: field 1 (value_data) + null, + Collections.singletonList(topicUri), + (Pattern) null, + sourceProperties, + com.pinterest.flink.streaming.connectors.psc.config.StartupMode.EARLIEST, + new HashMap<>(), + 0L, + com.pinterest.flink.streaming.connectors.psc.config.BoundedMode.UNBOUNDED, + new HashMap<>(), + 0L, + false, + "test-table"); + + // Query: SELECT nested_key.id FROM table + // This is a NESTED projection on a KEY field: path [0, 0] means field 0 (nested_key), sub-field 0 (id) + final int[][] projectedFields = new int[][] {new int[] {0, 0}}; + final DataType projectedProducedType = + DataTypes.ROW(DataTypes.FIELD("id", DataTypes.BIGINT())).notNull(); + + // Nested projection on key fields should be supported + ((SupportsProjectionPushDown) source).applyProjection(projectedFields, projectedProducedType); + + source.getScanRuntimeProvider(ScanRuntimeProviderContext.INSTANCE); + + // Key format should receive a DataType with only the projected nested field. + // After fix: nested fields use dot notation (e.g., "nested_key.id") for Thrift compatibility. + final DataType capturedKeyDecoderProduced = keyFormat.producedDataType; + assertThat(capturedKeyDecoderProduced).isNotNull(); + assertThat(org.apache.flink.table.types.logical.utils.LogicalTypeChecks + .getFieldNames(capturedKeyDecoderProduced.getLogicalType())) + .containsExactly("nested_key.id"); + + // Verify the keyProjection contains the nested path + assertThat(source.keyProjection.length).isEqualTo(1); + assertThat(source.keyProjection[0]).containsExactly(0, 0); + + // Verify keyOutputProjection maps to the correct output position + assertThat(source.keyOutputProjection).containsExactly(0); + } + + /** + * Test: Top-level projection on a KEY field (no nesting). + * Schema: nested_key ROW (KEY), value_data STRING (VALUE) + * Query: SELECT nested_key, value_data FROM table + * + * Verifies that selecting the entire key field works correctly. + */ + @Test + public void testTopLevelProjectionOnKeyField() { + final String topicUri = + com.pinterest.flink.streaming.connectors.psc.PscTestEnvironmentWithKafkaAsPubSub + .PSC_TEST_CLUSTER0_URI_PREFIX + + "key-projection-test-topic"; + + final Properties sourceProperties = new Properties(); + sourceProperties.setProperty( + com.pinterest.flink.connector.psc.PscFlinkConfiguration.CLUSTER_URI_CONFIG, + com.pinterest.flink.streaming.connectors.psc.PscTestEnvironmentWithKafkaAsPubSub + .PSC_TEST_CLUSTER0_URI_PREFIX); + sourceProperties.setProperty( + com.pinterest.psc.config.PscConfiguration.PSC_CONSUMER_GROUP_ID, "dummy"); + sourceProperties.setProperty("client.id.prefix", "test"); + sourceProperties.setProperty( + com.pinterest.flink.connector.psc.source.PscSourceOptions + .PARTITION_DISCOVERY_INTERVAL_MS + .key(), + "1000"); + + final DecodingFormatMock keyFormat = new DecodingFormatMock(",", true); + final DecodingFormatMock valueFormat = new DecodingFormatMock(",", true); + + final PscDynamicSource source = + new PscDynamicSource( + NESTED_KEY_PHYSICAL_TYPE, + keyFormat, + valueFormat, + new int[] {0}, // keyProjection: field 0 (nested_key) + new int[] {1}, // valueProjection: field 1 (value_data) + null, + Collections.singletonList(topicUri), + (Pattern) null, + sourceProperties, + com.pinterest.flink.streaming.connectors.psc.config.StartupMode.EARLIEST, + new HashMap<>(), + 0L, + com.pinterest.flink.streaming.connectors.psc.config.BoundedMode.UNBOUNDED, + new HashMap<>(), + 0L, + false, + "test-table"); + + // Query: SELECT nested_key, value_data FROM table (both key and value) + final int[][] projectedFields = new int[][] { + new int[] {0}, // nested_key (entire key field) + new int[] {1} // value_data + }; + final DataType projectedProducedType = + DataTypes.ROW( + DataTypes.FIELD( + "nested_key", + DataTypes.ROW( + DataTypes.FIELD("id", DataTypes.BIGINT()), + DataTypes.FIELD("name", DataTypes.STRING()))), + DataTypes.FIELD("value_data", DataTypes.STRING())) + .notNull(); + + ((SupportsProjectionPushDown) source).applyProjection(projectedFields, projectedProducedType); + + source.getScanRuntimeProvider(ScanRuntimeProviderContext.INSTANCE); + + // Key format should see the full nested_key structure + // Note: flattenToNames returns top-level field name for non-nested projection + final DataType capturedKeyDecoderProduced = keyFormat.producedDataType; + assertThat(capturedKeyDecoderProduced).isNotNull(); + assertThat(DataTypeUtils.flattenToNames(capturedKeyDecoderProduced)) + .containsExactly("nested_key"); + + // Value format should see value_data + final DataType capturedValueDecoderProduced = valueFormat.producedDataType; + assertThat(capturedValueDecoderProduced).isNotNull(); + assertThat(DataTypeUtils.flattenToNames(capturedValueDecoderProduced)) + .containsExactly("value_data"); + } + + /** + * Test: Mixed nested projection on both KEY and VALUE fields. + * Schema: nested_key ROW (KEY), nested_value ROW (VALUE) + * Query: SELECT nested_key.id, nested_value.x FROM table + */ + @Test + public void testNestedProjectionOnBothKeyAndValueFields() { + // Schema with nested types for both key and value + final DataType mixedNestedType = + DataTypes.ROW( + DataTypes.FIELD( + "nested_key", + DataTypes.ROW( + DataTypes.FIELD("id", DataTypes.BIGINT()), + DataTypes.FIELD("name", DataTypes.STRING()))), + DataTypes.FIELD( + "nested_value", + DataTypes.ROW( + DataTypes.FIELD("x", DataTypes.INT()), + DataTypes.FIELD("y", DataTypes.STRING())))) + .notNull(); + + final String topicUri = + com.pinterest.flink.streaming.connectors.psc.PscTestEnvironmentWithKafkaAsPubSub + .PSC_TEST_CLUSTER0_URI_PREFIX + + "mixed-nested-projection-test-topic"; + + final Properties sourceProperties = new Properties(); + sourceProperties.setProperty( + com.pinterest.flink.connector.psc.PscFlinkConfiguration.CLUSTER_URI_CONFIG, + com.pinterest.flink.streaming.connectors.psc.PscTestEnvironmentWithKafkaAsPubSub + .PSC_TEST_CLUSTER0_URI_PREFIX); + sourceProperties.setProperty( + com.pinterest.psc.config.PscConfiguration.PSC_CONSUMER_GROUP_ID, "dummy"); + sourceProperties.setProperty("client.id.prefix", "test"); + sourceProperties.setProperty( + com.pinterest.flink.connector.psc.source.PscSourceOptions + .PARTITION_DISCOVERY_INTERVAL_MS + .key(), + "1000"); + + final DecodingFormatMock keyFormat = new DecodingFormatMock(",", true); + final DecodingFormatMock valueFormat = new DecodingFormatMock(",", true); + + final PscDynamicSource source = + new PscDynamicSource( + mixedNestedType, + keyFormat, + valueFormat, + new int[] {0}, // keyProjection: field 0 (nested_key) + new int[] {1}, // valueProjection: field 1 (nested_value) + null, + Collections.singletonList(topicUri), + (Pattern) null, + sourceProperties, + com.pinterest.flink.streaming.connectors.psc.config.StartupMode.EARLIEST, + new HashMap<>(), + 0L, + com.pinterest.flink.streaming.connectors.psc.config.BoundedMode.UNBOUNDED, + new HashMap<>(), + 0L, + false, + "test-table"); + + // Query: SELECT nested_key.id, nested_value.x FROM table + final int[][] projectedFields = new int[][] { + new int[] {0, 0}, // nested_key.id + new int[] {1, 0} // nested_value.x + }; + final DataType projectedProducedType = + DataTypes.ROW( + DataTypes.FIELD("id", DataTypes.BIGINT()), + DataTypes.FIELD("x", DataTypes.INT())) + .notNull(); + + ((SupportsProjectionPushDown) source).applyProjection(projectedFields, projectedProducedType); + + source.getScanRuntimeProvider(ScanRuntimeProviderContext.INSTANCE); + + // Key format should see only nested_key.id (with dot notation for Thrift compatibility) + final DataType capturedKeyDecoderProduced = keyFormat.producedDataType; + assertThat(capturedKeyDecoderProduced).isNotNull(); + assertThat(org.apache.flink.table.types.logical.utils.LogicalTypeChecks + .getFieldNames(capturedKeyDecoderProduced.getLogicalType())) + .containsExactly("nested_key.id"); + + // Value format should see only nested_value.x (with dot notation for Thrift compatibility) + final DataType capturedValueDecoderProduced = valueFormat.producedDataType; + assertThat(capturedValueDecoderProduced).isNotNull(); + assertThat(org.apache.flink.table.types.logical.utils.LogicalTypeChecks + .getFieldNames(capturedValueDecoderProduced.getLogicalType())) + .containsExactly("nested_value.x"); + + // Verify projections are correctly separated + assertThat(source.keyProjection.length).isEqualTo(1); + assertThat(source.keyProjection[0]).containsExactly(0, 0); + assertThat(source.valueProjection.length).isEqualTo(1); + assertThat(source.valueProjection[0]).containsExactly(1, 0); + + // Verify output projections + assertThat(source.keyOutputProjection).containsExactly(0); + assertThat(source.valueOutputProjection).containsExactly(1); + } +} diff --git a/psc-integration-test/pom.xml b/psc-integration-test/pom.xml index 0ad3095f..4545a285 100644 --- a/psc-integration-test/pom.xml +++ b/psc-integration-test/pom.xml @@ -14,7 +14,7 @@ 3.4.0 - 0.2.21 + 1.0.2 diff --git a/psc/pom.xml b/psc/pom.xml index 7ccc5d35..25a4448c 100644 --- a/psc/pom.xml +++ b/psc/pom.xml @@ -15,7 +15,7 @@ 3.4.0 - 0.2.21 + 1.0.2 diff --git a/psc/src/main/java/com/pinterest/psc/common/TopicUriPartition.java b/psc/src/main/java/com/pinterest/psc/common/TopicUriPartition.java index 39b8a1c0..333b56cc 100644 --- a/psc/src/main/java/com/pinterest/psc/common/TopicUriPartition.java +++ b/psc/src/main/java/com/pinterest/psc/common/TopicUriPartition.java @@ -12,6 +12,7 @@ public class TopicUriPartition implements Comparable, Seriali private final String topicUriStr; private final int partition; private TopicUri backendTopicUri; + private transient int cachedHashCode; /** * Builds a TopicUriPartition instance with the default partition value (-1). This is meant to be used in @@ -53,6 +54,7 @@ public TopicUriPartition(TopicUri topicUri, int partition) { protected void setTopicUri(TopicUri backendTopicUri) { this.backendTopicUri = backendTopicUri; + this.cachedHashCode = 0; } /** @@ -106,10 +108,14 @@ public boolean equals(Object other) { @Override public int hashCode() { - int result = topicUriStr.hashCode(); - result = 31 * result + (backendTopicUri == null ? 0 : backendTopicUri.hashCode()); - result = 31 * result + partition; - return result; + int h = cachedHashCode; + if (h == 0) { + h = topicUriStr.hashCode(); + h = 31 * h + (backendTopicUri == null ? 0 : backendTopicUri.hashCode()); + h = 31 * h + partition; + cachedHashCode = h; + } + return h; } @Override diff --git a/psc/src/main/java/com/pinterest/psc/config/PscMetadataClientToMemqConsumerConfigConverter.java b/psc/src/main/java/com/pinterest/psc/config/PscMetadataClientToMemqConsumerConfigConverter.java new file mode 100644 index 00000000..2c5ce98e --- /dev/null +++ b/psc/src/main/java/com/pinterest/psc/config/PscMetadataClientToMemqConsumerConfigConverter.java @@ -0,0 +1,26 @@ +package com.pinterest.psc.config; + +import com.pinterest.memq.client.commons.ConsumerConfigs; +import com.pinterest.psc.common.TopicUri; + +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +public class PscMetadataClientToMemqConsumerConfigConverter extends PscMetadataClientToBackendMetatadataClientConfigCoverter { + @Override + protected Map getConfigConverterMap() { + return new HashMap() { + private static final long serialVersionUID = 1L; + + { + put(PscConfiguration.PSC_METADATA_CLIENT_ID, ConsumerConfigs.CLIENT_ID); + } + }; + } + + @Override + public Properties convert(PscConfigurationInternal pscConfigurationInternal, TopicUri topicUri) { + return super.convert(pscConfigurationInternal, topicUri); + } +} diff --git a/psc/src/main/java/com/pinterest/psc/consumer/PscConsumerMessagesIterable.java b/psc/src/main/java/com/pinterest/psc/consumer/PscConsumerMessagesIterable.java index 211a7511..9f1ace39 100644 --- a/psc/src/main/java/com/pinterest/psc/consumer/PscConsumerMessagesIterable.java +++ b/psc/src/main/java/com/pinterest/psc/consumer/PscConsumerMessagesIterable.java @@ -1,7 +1,9 @@ package com.pinterest.psc.consumer; import com.pinterest.psc.common.TopicUriPartition; +import com.pinterest.psc.logging.PscLogger; +import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; @@ -13,11 +15,18 @@ public class PscConsumerMessagesIterable implements Iterable> { + private static final PscLogger logger = PscLogger.getLogger(PscConsumerMessagesIterable.class); + List> messages; Map>> messagesByTopicUriPartition; public PscConsumerMessagesIterable(PscConsumerPollMessageIterator iterator) { this.messages = iterator.asList(); + try { + iterator.close(); + } catch (IOException e) { + logger.warn("Failed to close poll message iterator", e); + } this.messagesByTopicUriPartition = new HashMap<>(); for (PscConsumerMessage message : messages) { TopicUriPartition topicUriPartition = message.getMessageId().getTopicUriPartition(); diff --git a/psc/src/main/java/com/pinterest/psc/consumer/memq/MemqTopicUri.java b/psc/src/main/java/com/pinterest/psc/consumer/memq/MemqTopicUri.java index d25c1a65..3a1760a5 100644 --- a/psc/src/main/java/com/pinterest/psc/consumer/memq/MemqTopicUri.java +++ b/psc/src/main/java/com/pinterest/psc/consumer/memq/MemqTopicUri.java @@ -9,7 +9,7 @@ public class MemqTopicUri extends BaseTopicUri { public static final String PLAINTEXT_PROTOCOL = "plaintext"; public static final String SECURE_PROTOCOL = "secure"; - MemqTopicUri(TopicUri topicUri) { + public MemqTopicUri(TopicUri topicUri) { super(topicUri); } diff --git a/psc/src/main/java/com/pinterest/psc/consumer/memq/PscMemqConsumer.java b/psc/src/main/java/com/pinterest/psc/consumer/memq/PscMemqConsumer.java index e93dd667..33af7283 100644 --- a/psc/src/main/java/com/pinterest/psc/consumer/memq/PscMemqConsumer.java +++ b/psc/src/main/java/com/pinterest/psc/consumer/memq/PscMemqConsumer.java @@ -29,6 +29,7 @@ import com.pinterest.psc.exception.consumer.ConsumerException; import com.pinterest.psc.exception.consumer.WakeupException; import com.pinterest.psc.logging.PscLogger; +import com.pinterest.psc.common.PscUtils; import com.pinterest.psc.metrics.Metric; import com.pinterest.psc.metrics.MetricName; import com.pinterest.psc.metrics.PscMetricRegistryManager; @@ -55,6 +56,7 @@ public class PscMemqConsumer extends PscBackendConsumer { public static final String END_OF_BATCH_EVENT = "end_of_batch"; + private static final String MEMQ_CONSUMER_METRIC_GROUP = "memq-consumer-metrics"; private static final PscLogger logger = PscLogger.getLogger(PscMemqConsumer.class); @VisibleForTesting @@ -599,6 +601,7 @@ public void wakeup() { public void close() throws ConsumerException { if (memqConsumer == null) throw new ConsumerException("[Memq] Consumer is not initialized prior to call to close()."); + scheduler.shutdown(); currentSubscription.clear(); try { memqConsumer.close(); @@ -640,7 +643,8 @@ public Map startOffsets(Set topicUri Map startOffsets = memqConsumer .getEarliestOffsets(partitionToTopicUriPartition.keySet()); return startOffsets.entrySet().stream().collect(Collectors - .toMap(entry -> partitionToTopicUriPartition.get(entry.getKey()), Map.Entry::getValue)); + .toMap(entry -> partitionToTopicUriPartition.get(entry.getKey()), + entry -> kafkaOffsetToComposite(entry.getValue()))); } @Override @@ -654,7 +658,8 @@ public Map endOffsets(Set topicUriPa Map endOffsets = memqConsumer .getLatestOffsets(partitionToTopicUriPartition.keySet()); return endOffsets.entrySet().stream().collect(Collectors - .toMap(entry -> partitionToTopicUriPartition.get(entry.getKey()), Map.Entry::getValue)); + .toMap(entry -> partitionToTopicUriPartition.get(entry.getKey()), + entry -> kafkaOffsetToComposite(entry.getValue()))); } @Override @@ -714,7 +719,57 @@ public PscConfiguration getConfiguration() { @Override public Map metrics() throws ConsumerException { - return Collections.emptyMap(); + if (memqConsumer == null) { + return Collections.emptyMap(); + } + + MetricRegistry registry = memqConsumer.getMetricRegistry(); + if (registry == null) { + return Collections.emptyMap(); + } + + Map result = new HashMap<>(); + for (Map.Entry entry : registry.getMetrics().entrySet()) { + String name = entry.getKey(); + com.codahale.metrics.Metric dropwizardMetric = entry.getValue(); + + Map tags = new HashMap<>(); + tags.put("backend", PscUtils.BACKEND_TYPE_MEMQ); + + MetricName metricName = new MetricName(name, MEMQ_CONSUMER_METRIC_GROUP, "", tags); + result.put(metricName, new LiveDropwizardMetric(metricName, dropwizardMetric)); + } + + return result; + } + + /** + * A PSC Metric backed by a live Dropwizard metric reference. + * Each call to {@link #metricValue()} reads the current value from the + * underlying Dropwizard metric rather than returning a stale snapshot. + */ + private static class LiveDropwizardMetric extends Metric { + private final com.codahale.metrics.Metric dropwizardMetric; + + LiveDropwizardMetric(MetricName metricName, com.codahale.metrics.Metric dropwizardMetric) { + super(metricName, null); + this.dropwizardMetric = dropwizardMetric; + } + + @Override + public Object metricValue() { + if (dropwizardMetric instanceof Counter) + return ((Counter) dropwizardMetric).getCount(); + if (dropwizardMetric instanceof Gauge) + return ((Gauge) dropwizardMetric).getValue(); + if (dropwizardMetric instanceof Meter) + return ((Meter) dropwizardMetric).getCount(); + if (dropwizardMetric instanceof Histogram) + return ((Histogram) dropwizardMetric).getSnapshot().getMax(); + if (dropwizardMetric instanceof Timer) + return ((Timer) dropwizardMetric).getSnapshot().getMax(); + return -1L; + } } /** @@ -746,6 +801,16 @@ private boolean isCurrentTopicPartition(TopicUriPartition topicUriPartition) { return this.currentSubscription.contains(topicUriPartition.getTopicUri()) || this.currentAssignment.contains(topicUriPartition); } + /** + * Converts a raw Kafka notification offset to a composite MemqOffset (with message offset 0). + * All offsets exposed by PscMemqConsumer must be in composite format so that + * {@link #seekToOffset} can correctly decode them back via + * {@link MemqOffset#convertPscOffsetToMemqOffset}. + */ + private static long kafkaOffsetToComposite(long kafkaOffset) { + return new MemqOffset(kafkaOffset, 0).toLong(); + } + private MemqConsumer getMetadataConsumer(TopicUri topicUri) throws ConsumerException { try { Properties tmpProps = new Properties(properties); diff --git a/psc/src/main/java/com/pinterest/psc/metadata/client/memq/PscMemqMetadataClient.java b/psc/src/main/java/com/pinterest/psc/metadata/client/memq/PscMemqMetadataClient.java new file mode 100644 index 00000000..f26bf516 --- /dev/null +++ b/psc/src/main/java/com/pinterest/psc/metadata/client/memq/PscMemqMetadataClient.java @@ -0,0 +1,239 @@ +package com.pinterest.psc.metadata.client.memq; + +import com.pinterest.memq.client.commons.ConsumerConfigs; +import com.pinterest.memq.client.commons.serde.ByteArrayDeserializer; +import com.pinterest.memq.client.consumer.MemqConsumer; +import com.pinterest.psc.common.BaseTopicUri; +import com.pinterest.psc.common.TopicRn; +import com.pinterest.psc.common.TopicUri; +import com.pinterest.psc.common.TopicUriPartition; +import com.pinterest.psc.config.PscConfigurationInternal; +import com.pinterest.psc.config.PscMetadataClientToMemqConsumerConfigConverter; +import com.pinterest.psc.consumer.memq.MemqOffset; +import com.pinterest.psc.consumer.memq.MemqTopicUri; +import com.pinterest.psc.environment.Environment; +import com.pinterest.psc.exception.startup.ConfigurationException; +import com.pinterest.psc.logging.PscLogger; +import com.pinterest.psc.metadata.MetadataUtils; +import com.pinterest.psc.metadata.TopicUriMetadata; +import com.pinterest.psc.metadata.client.PscBackendMetadataClient; +import com.pinterest.psc.metadata.client.PscMetadataClient; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; + +/** + * A Memq-specific implementation of the {@link PscBackendMetadataClient}. + * Uses a {@link MemqConsumer} to query metadata since Memq does not have a dedicated admin client. + */ +public class PscMemqMetadataClient extends PscBackendMetadataClient { + + private static final PscLogger logger = PscLogger.getLogger(PscMemqMetadataClient.class); + protected MemqConsumer memqConsumer; + + @Override + public void initialize( + TopicUri topicUri, + Environment env, + PscConfigurationInternal pscConfigurationInternal + ) throws ConfigurationException { + super.initialize(topicUri, env, pscConfigurationInternal); + Properties properties = new PscMetadataClientToMemqConsumerConfigConverter() + .convert(pscConfigurationInternal, topicUri); + properties.setProperty(ConsumerConfigs.BOOTSTRAP_SERVERS, discoveryConfig.getConnect()); + properties.setProperty(ConsumerConfigs.CLIENT_ID, + pscConfigurationInternal.getMetadataClientId()); + properties.setProperty(ConsumerConfigs.GROUP_ID, + pscConfigurationInternal.getMetadataClientId()); + properties.setProperty(ConsumerConfigs.KEY_DESERIALIZER_CLASS_KEY, + ByteArrayDeserializer.class.getName()); + properties.put(ConsumerConfigs.KEY_DESERIALIZER_CLASS_CONFIGS_KEY, new Properties()); + properties.setProperty(ConsumerConfigs.VALUE_DESERIALIZER_CLASS_KEY, + ByteArrayDeserializer.class.getName()); + properties.put(ConsumerConfigs.VALUE_DESERIALIZER_CLASS_CONFIGS_KEY, new Properties()); + properties.setProperty(ConsumerConfigs.DIRECT_CONSUMER, "false"); + try { + this.memqConsumer = new MemqConsumer<>(properties); + } catch (Exception e) { + throw new ConfigurationException("Failed to create Memq consumer for metadata client", e); + } + logger.info("Initialized PscMemqMetadataClient with properties: " + properties); + } + + @Override + public List listTopicRns(Duration duration) + throws ExecutionException, InterruptedException, TimeoutException { + throw new UnsupportedOperationException( + "[Memq] Listing all topics is not supported by the Memq backend."); + } + + @Override + public Map describeTopicUris( + Collection topicUris, + Duration duration + ) throws ExecutionException, InterruptedException, TimeoutException { + Map result = new HashMap<>(); + for (TopicUri tu : topicUris) { + subscribe(tu.getTopic()); + List partitions = memqConsumer.getPartition(); + List topicUriPartitions = new ArrayList<>(); + for (int partition : partitions) { + topicUriPartitions.add(createMemqTopicUriPartition(tu, partition)); + } + result.put(tu, new TopicUriMetadata(tu, topicUriPartitions)); + } + return result; + } + + @Override + public Map listOffsets( + Map topicUriPartitionsAndOptions, + Duration duration + ) throws ExecutionException, InterruptedException, TimeoutException { + Map> earliestByTopic = new HashMap<>(); + Map> latestByTopic = new HashMap<>(); + + for (Map.Entry entry : + topicUriPartitionsAndOptions.entrySet()) { + TopicUriPartition tup = entry.getKey(); + String topic = tup.getTopicUri().getTopic(); + + if (entry.getValue() == PscMetadataClient.MetadataClientOption.OFFSET_SPEC_EARLIEST) { + earliestByTopic.computeIfAbsent(topic, k -> new HashSet<>()).add(tup.getPartition()); + } else if (entry.getValue() == PscMetadataClient.MetadataClientOption.OFFSET_SPEC_LATEST) { + latestByTopic.computeIfAbsent(topic, k -> new HashSet<>()).add(tup.getPartition()); + } else { + throw new IllegalArgumentException( + "Unsupported MetadataClientOption for listOffsets(): " + entry.getValue()); + } + } + + Map result = new HashMap<>(); + Set allTopics = new HashSet<>(); + allTopics.addAll(earliestByTopic.keySet()); + allTopics.addAll(latestByTopic.keySet()); + + for (String topic : allTopics) { + subscribe(topic); + + Set earliestPartitions = earliestByTopic.getOrDefault(topic, new HashSet<>()); + if (!earliestPartitions.isEmpty()) { + Map offsets = memqConsumer.getEarliestOffsets(earliestPartitions); + for (Map.Entry e : offsets.entrySet()) { + TopicRn topicRn = MetadataUtils.createTopicRn(topicUri, topic); + result.put(createMemqTopicUriPartition(topicRn, e.getKey()), kafkaOffsetToComposite(e.getValue())); + } + } + + Set latestPartitions = latestByTopic.getOrDefault(topic, new HashSet<>()); + if (!latestPartitions.isEmpty()) { + Map offsets = memqConsumer.getLatestOffsets(latestPartitions); + for (Map.Entry e : offsets.entrySet()) { + TopicRn topicRn = MetadataUtils.createTopicRn(topicUri, topic); + result.put(createMemqTopicUriPartition(topicRn, e.getKey()), kafkaOffsetToComposite(e.getValue())); + } + } + } + return result; + } + + @Override + public Map listOffsetsForTimestamps( + Map topicUriPartitionsAndTimes, + Duration duration + ) throws ExecutionException, InterruptedException, TimeoutException { + Map> timestampsByTopic = new HashMap<>(); + + for (Map.Entry entry : topicUriPartitionsAndTimes.entrySet()) { + TopicUriPartition tup = entry.getKey(); + String topic = tup.getTopicUri().getTopic(); + timestampsByTopic.computeIfAbsent(topic, k -> new HashMap<>()) + .put(tup.getPartition(), entry.getValue()); + } + + Map result = new HashMap<>(); + for (Map.Entry> entry : timestampsByTopic.entrySet()) { + String topic = entry.getKey(); + subscribe(topic); + Map offsets = memqConsumer.offsetsOfTimestamps(entry.getValue()); + for (Map.Entry offsetEntry : offsets.entrySet()) { + TopicRn topicRn = MetadataUtils.createTopicRn(topicUri, topic); + result.put( + createMemqTopicUriPartition(topicRn, offsetEntry.getKey()), + kafkaOffsetToComposite(offsetEntry.getValue()) + ); + } + } + return result; + } + + @Override + public Map listOffsetsForConsumerGroup( + String consumerGroupId, + Collection topicUriPartitions, + Duration duration + ) throws ExecutionException, InterruptedException, TimeoutException { + Map> partitionsByTopic = new HashMap<>(); + for (TopicUriPartition tup : topicUriPartitions) { + String topic = tup.getTopicUri().getTopic(); + partitionsByTopic.computeIfAbsent(topic, k -> new HashSet<>()).add(tup.getPartition()); + } + + Map result = new HashMap<>(); + for (Map.Entry> entry : partitionsByTopic.entrySet()) { + String topic = entry.getKey(); + subscribe(topic); + for (int partition : entry.getValue()) { + long committedOffset = memqConsumer.committed(partition); + if (committedOffset == -1L) { + logger.warn( + "Consumer group {} has no committed offset for topic {} partition {}", + consumerGroupId, topic, partition + ); + continue; + } + TopicRn topicRn = MetadataUtils.createTopicRn(topicUri, topic); + result.put(createMemqTopicUriPartition(topicRn, partition), kafkaOffsetToComposite(committedOffset)); + } + } + return result; + } + + @Override + public void close() throws IOException { + if (memqConsumer != null) + memqConsumer.close(); + logger.info("Closed PscMemqMetadataClient"); + } + + private void subscribe(String topic) throws ExecutionException { + try { + memqConsumer.subscribe(topic); + } catch (Exception e) { + throw new ExecutionException("Failed to subscribe to Memq topic " + topic, e); + } + } + + private TopicUriPartition createMemqTopicUriPartition(TopicRn topicRn, int partition) { + return new TopicUriPartition( + new MemqTopicUri(new BaseTopicUri(topicUri.getProtocol(), topicRn)), partition); + } + + private TopicUriPartition createMemqTopicUriPartition(TopicUri topicUri, int partition) { + return new TopicUriPartition(new MemqTopicUri(topicUri), partition); + } + + private static long kafkaOffsetToComposite(long kafkaOffset) { + return new MemqOffset(kafkaOffset, 0).toLong(); + } +} diff --git a/psc/src/main/java/com/pinterest/psc/metadata/creation/PscMemqMetadataClientCreator.java b/psc/src/main/java/com/pinterest/psc/metadata/creation/PscMemqMetadataClientCreator.java new file mode 100644 index 00000000..d5645e1b --- /dev/null +++ b/psc/src/main/java/com/pinterest/psc/metadata/creation/PscMemqMetadataClientCreator.java @@ -0,0 +1,37 @@ +package com.pinterest.psc.metadata.creation; + +import com.pinterest.psc.common.PscUtils; +import com.pinterest.psc.common.TopicUri; +import com.pinterest.psc.config.PscConfigurationInternal; +import com.pinterest.psc.consumer.memq.MemqTopicUri; +import com.pinterest.psc.environment.Environment; +import com.pinterest.psc.exception.startup.ConfigurationException; +import com.pinterest.psc.exception.startup.TopicUriSyntaxException; +import com.pinterest.psc.logging.PscLogger; +import com.pinterest.psc.metadata.client.memq.PscMemqMetadataClient; + +/** + * A class that creates a {@link com.pinterest.psc.metadata.client.PscBackendMetadataClient} for Memq. + */ +@PscMetadataClientCreatorPlugin(backend = PscUtils.BACKEND_TYPE_MEMQ, priority = 1) +public class PscMemqMetadataClientCreator extends PscBackendMetadataClientCreator { + + private static final PscLogger logger = PscLogger.getLogger(PscMemqMetadataClientCreator.class); + + @Override + public PscMemqMetadataClient create(Environment env, PscConfigurationInternal pscConfigurationInternal, TopicUri clusterUri) throws ConfigurationException { + logger.info("Creating Memq metadata client for clusterUri: " + clusterUri); + PscMemqMetadataClient pscMemqMetadataClient = new PscMemqMetadataClient(); + pscMemqMetadataClient.initialize( + clusterUri, + env, + pscConfigurationInternal + ); + return pscMemqMetadataClient; + } + + @Override + public TopicUri validateBackendTopicUri(TopicUri topicUri) throws TopicUriSyntaxException { + return MemqTopicUri.validate(topicUri); + } +} diff --git a/psc/src/main/java/com/pinterest/psc/metadata/creation/PscMetadataClientCreatorManager.java b/psc/src/main/java/com/pinterest/psc/metadata/creation/PscMetadataClientCreatorManager.java index 0834411e..5884f29f 100644 --- a/psc/src/main/java/com/pinterest/psc/metadata/creation/PscMetadataClientCreatorManager.java +++ b/psc/src/main/java/com/pinterest/psc/metadata/creation/PscMetadataClientCreatorManager.java @@ -7,10 +7,10 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.TreeMap; /** * Manages the different {@link PscBackendMetadataClientCreator} implementations and provides a registry of them. @@ -28,7 +28,7 @@ public class PscMetadataClientCreatorManager { private static Map> findAndRegisterMetadataClientCreators(String packageName) { synchronized (PscUtils.lock) { - Map> backendCreatorRegistry = new HashMap<>(); + Map> backendCreatorRegistry = new TreeMap<>(); Reflections reflections = new Reflections(packageName.trim()); Set> annotatedClasses = reflections.getTypesAnnotatedWith(PscMetadataClientCreatorPlugin.class); for (Class annotatedClass : annotatedClasses) { diff --git a/psc/src/main/java/com/pinterest/psc/metrics/PscMetricRegistryManager.java b/psc/src/main/java/com/pinterest/psc/metrics/PscMetricRegistryManager.java index 7598701c..a66bdc67 100644 --- a/psc/src/main/java/com/pinterest/psc/metrics/PscMetricRegistryManager.java +++ b/psc/src/main/java/com/pinterest/psc/metrics/PscMetricRegistryManager.java @@ -6,7 +6,7 @@ import com.codahale.metrics.MetricFilter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.ScheduledReporter; -import com.codahale.metrics.SlidingTimeWindowArrayReservoir; +import com.codahale.metrics.ExponentiallyDecayingReservoir; import com.codahale.metrics.Snapshot; import com.codahale.metrics.jvm.CachedThreadStatesGaugeSet; import com.codahale.metrics.jvm.GarbageCollectorMetricSet; @@ -224,7 +224,7 @@ public void updateHistogramMetric(TopicUri topicUri, if (metricRegistry != null) { metricRegistry.histogram(metricKey, () -> new Histogram( - new SlidingTimeWindowArrayReservoir(1, TimeUnit.MINUTES) + new ExponentiallyDecayingReservoir() ) ).update(metricValue); } diff --git a/psc/src/test/java/com/pinterest/psc/consumer/memq/TestMemqOffsetRoundTrip.java b/psc/src/test/java/com/pinterest/psc/consumer/memq/TestMemqOffsetRoundTrip.java new file mode 100644 index 00000000..277a6f7c --- /dev/null +++ b/psc/src/test/java/com/pinterest/psc/consumer/memq/TestMemqOffsetRoundTrip.java @@ -0,0 +1,83 @@ +package com.pinterest.psc.consumer.memq; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Verifies that raw Kafka notification offsets survive a round-trip through + * the composite MemqOffset encoding used by PscMemqConsumer. + * + * Bug context: offsetsForTimes / endOffsets / startOffsets were returning raw + * Kafka offsets, but seekToOffset decodes them as composite MemqOffsets + * (bit-shifting right by 19). Without wrapping via kafkaOffsetToComposite, + * a raw offset like 2205646 would be decoded as batch=4 instead of batch=2205646. + */ +public class TestMemqOffsetRoundTrip { + + @Test + public void testRawKafkaOffsetRoundTrips() { + long rawKafkaOffset = 2205646L; + + long composite = new MemqOffset(rawKafkaOffset, 0).toLong(); + MemqOffset decoded = MemqOffset.convertPscOffsetToMemqOffset(composite); + + assertEquals(rawKafkaOffset, decoded.getBatchOffset()); + assertEquals(0, decoded.getMessageOffset()); + } + + @Test + public void testRawOffsetWithoutEncodingIsCorrupted() { + long rawKafkaOffset = 2205646L; + + // Decoding a raw offset directly (the old bug) produces wrong batch offset + MemqOffset decoded = MemqOffset.convertPscOffsetToMemqOffset(rawKafkaOffset); + + // 2205646 >>> 19 = 4, not 2205646 + assertEquals(4, decoded.getBatchOffset(), + "Raw offset decoded without encoding should lose upper bits"); + } + + @Test + public void testSmallOffsetRoundTrips() { + long rawKafkaOffset = 42L; + + long composite = new MemqOffset(rawKafkaOffset, 0).toLong(); + MemqOffset decoded = MemqOffset.convertPscOffsetToMemqOffset(composite); + + assertEquals(rawKafkaOffset, decoded.getBatchOffset()); + assertEquals(0, decoded.getMessageOffset()); + } + + @Test + public void testZeroOffsetRoundTrips() { + long composite = new MemqOffset(0, 0).toLong(); + MemqOffset decoded = MemqOffset.convertPscOffsetToMemqOffset(composite); + + assertEquals(0, decoded.getBatchOffset()); + assertEquals(0, decoded.getMessageOffset()); + } + + @Test + public void testLargeOffsetRoundTrips() { + long rawKafkaOffset = 10_000_000L; + + long composite = new MemqOffset(rawKafkaOffset, 0).toLong(); + MemqOffset decoded = MemqOffset.convertPscOffsetToMemqOffset(composite); + + assertEquals(rawKafkaOffset, decoded.getBatchOffset()); + assertEquals(0, decoded.getMessageOffset()); + } + + @Test + public void testCompositeWithMessageOffsetRoundTrips() { + long batchOffset = 500L; + int messageOffset = 1234; + + long composite = new MemqOffset(batchOffset, messageOffset).toLong(); + MemqOffset decoded = MemqOffset.convertPscOffsetToMemqOffset(composite); + + assertEquals(batchOffset, decoded.getBatchOffset()); + assertEquals(messageOffset, decoded.getMessageOffset()); + } +}