# Provenance: the code below was carried by the following git patch.
#
#   From c95aab35437669ab476cb08a0293e447b64a4586 Mon Sep 17 00:00:00 2001
#   From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com>
#   Date: Thu, 25 Jun 2026 21:58:59 +0000
#   Subject: [PATCH] test: add ASV microbenchmark for
#    SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
#   ---
#    python/benchmarks/bench_eval_type.py | 165 +++++++++++++++++++++++++++
#    1 file changed, 165 insertions(+)
#
#   diff --git a/python/benchmarks/bench_eval_type.py b/python/benchmarks/bench_eval_type.py
#   index f220b804be0b2..cec2ead32be08 100644
#   --- a/python/benchmarks/bench_eval_type.py
#   +++ b/python/benchmarks/bench_eval_type.py
#   @@ -1981,3 +1981,168 @@
#
# The first class is unchanged context from the patch; everything after it is
# what the patch adds.


class TransformWithStatePandasUDFPeakmemBench(
    _TransformWithStatePandasBenchMixin, _PeakmemBenchBase
):
    pass


# -- SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF ----------------------------
# Stateful streaming with Pandas plus an initial-state dataset. The UDF
# signature is ``(api_client, mode, key, state_values, init_states)`` and
# returns ``Iterator[pandas.DataFrame]``.
#
# Unlike the plain TWS variant, the input wire stream packs two datasets into a
# single Arrow stream. Every batch has exactly two struct columns, ``inputData``
# and ``initState`` (see ``TransformWithStateInPySparkPythonInitialStateRunner``).
# A batch carries either inputData rows or initState rows -- never both -- and
# the inactive column is written as an all-null struct. To match the JVM
# ``initData ++ data`` ordering, all initial-state batches are emitted first
# (initState populated), followed by all data batches (inputData populated).
# ``TransformWithStateInPandasInitStateSerializer`` regroups rows by the leading
# key column, so each key surfaces as an init-only call followed by a data-only
# call; the empty side of each call is filtered out before the UDF sees it.


class _TransformWithStatePandasInitStateBenchMixin(_TransformWithStatePandasBenchMixin):
    """Scenario writer for SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF benchmarks.

    The streamed input comes from the plain-TWS scenario grid; on top of that,
    every group gets ``_INIT_ROWS_PER_GROUP`` seeded initial-state rows that
    share the input schema. Deserializing the initial state (flattening the
    nested struct and regrouping per key) happens during ``load_stream``, so
    that cost is paid whether or not the UDF ever reads ``init_states``.
    """

    # Kept small relative to the streamed data (one seeded chunk per key) so
    # that data deserialization remains the dominant cost -- this mirrors
    # production, where initial state loads once and input data streams per batch.
    _INIT_ROWS_PER_GROUP = 100

    @classmethod
    def _build_init_batches(cls, name):
        """Return the initial-state Arrow batches for scenario ``name``.

        The batches use the same columns as the input data (an int32 key column
        ``col_0`` followed by the scenario's value columns) but hold only
        ``_INIT_ROWS_PER_GROUP`` rows per group, laid out in ascending key
        order. The result is cut into slices of at most
        ``MockDataFactory.MAX_RECORDS_PER_BATCH`` rows.
        """
        # Seed first: the value generators are invoked below, in column order.
        np.random.seed(7)
        group_count, _, value_col_count, generators = cls._scenario_configs[name]
        row_count = group_count * cls._INIT_ROWS_PER_GROUP

        # Key column: each group id repeated once per seeded row.
        group_ids = np.arange(group_count, dtype=np.int32)
        columns = [pa.array(np.repeat(group_ids, cls._INIT_ROWS_PER_GROUP), type=pa.int32())]
        column_names = ["col_0"]

        # Value columns cycle through the scenario's generator pool.
        for col_idx in range(value_col_count):
            make_values = generators[col_idx % len(generators)][0]
            columns.append(make_values(row_count))
            column_names.append(f"col_{col_idx + 1}")

        seeded = pa.RecordBatch.from_arrays(columns, names=column_names)
        chunk_rows = MockDataFactory.MAX_RECORDS_PER_BATCH
        return [
            seeded.slice(start, min(chunk_rows, row_count - start))
            for start in range(0, row_count, chunk_rows)
        ]

    @staticmethod
    def _wrap_nested(flat_batch, struct_type, *, is_init):
        """Wrap a flat batch into a two-column ``inputData`` / ``initState`` batch.

        ``flat_batch``'s columns become the fields of the populated struct
        column (``initState`` when ``is_init`` is true, ``inputData``
        otherwise). The other column is an all-null array of ``struct_type``
        with the same number of rows, so ``flatten_columns`` in the serializer
        treats it as empty.
        """
        children = [flat_batch.column(idx) for idx in range(flat_batch.num_columns)]
        filled = pa.StructArray.from_arrays(children, names=flat_batch.schema.names)
        placeholder = pa.array([None] * flat_batch.num_rows, type=struct_type)
        if is_init:
            input_side, init_side = placeholder, filled
        else:
            input_side, init_side = filled, placeholder
        return pa.RecordBatch.from_arrays(
            [input_side, init_side], names=["inputData", "initState"]
        )

    # The three UDFs below are deliberately plain functions (no ``self``): they
    # are stored in ``_udfs`` while the class body executes and later handed to
    # ``MockProtocolWriter.write_udf_payload``. Each one is a generator that
    # produces output only for PROCESS_DATA calls.

    # Echo every input DataFrame untouched.
    def _tws_init_identity(api_client, mode, key, state_values, init_states):
        from pyspark.sql.streaming.stateful_processor_util import (
            TransformWithStateInPandasFuncMode,
        )

        is_data_call = mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
        if not is_data_call:
            return
        yield from state_values

    # Emit every input DataFrame sorted by its first column.
    def _tws_init_sort(api_client, mode, key, state_values, init_states):
        from pyspark.sql.streaming.stateful_processor_util import (
            TransformWithStateInPandasFuncMode,
        )

        is_data_call = mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
        if not is_data_call:
            return
        for pdf in state_values:
            yield pdf.sort_values(pdf.columns[0])

    # Emit a single (key, row count) row when the call carried any rows.
    def _tws_init_count(api_client, mode, key, state_values, init_states):
        import pandas as pd
        from pyspark.sql.streaming.stateful_processor_util import (
            TransformWithStateInPandasFuncMode,
        )

        is_data_call = mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
        if not is_data_call:
            return
        # state_values and init_states arrive on separate per-key calls; add up
        # whichever side is non-empty so both deserialization paths are counted.
        data_rows = sum(len(pdf) for pdf in state_values)
        init_rows = sum(len(pdf) for pdf in init_states)
        total = data_rows + init_rows
        if total:
            yield pd.DataFrame({"col_0": [key[0]], "col_1": [total]})

    # Maps UDF name -> (function, return type). A return type of None means
    # "echo the full input schema": the init-state worker path does not project
    # value columns, so identity/sort receive and return the key column too.
    # count_udf re-emits (key, total) explicitly, hence its own two-int schema.
    _udfs = {
        "identity_udf": (_tws_init_identity, None),
        "sort_udf": (_tws_init_sort, None),
        "count_udf": (
            _tws_init_count,
            StructType(
                [
                    StructField("col_0", IntegerType()),
                    StructField("col_1", IntegerType()),
                ]
            ),
        ),
    }
    # Benchmark parameter grid: every scenario of the plain-TWS mixin x every UDF.
    params = [
        list(_TransformWithStatePandasBenchMixin._scenario_configs),
        list(_udfs),
    ]
    param_names = ["scenario", "udf"]

    def _write_scenario(self, scenario, udf_name, buf):
        """Write the complete worker input for one (scenario, UDF) pair into ``buf``.

        ``scenario`` selects the data via ``_build_scenario`` and
        ``_build_init_batches``; ``udf_name`` is a key of ``_udfs``. All
        initial-state batches are written ahead of the data batches.
        """
        data_batches, schema = self._build_scenario(scenario)
        init_batches = self._build_init_batches(scenario)

        udf_func, declared_ret_type = self._udfs[udf_name]
        ret_type = schema if declared_ret_type is None else declared_ret_type

        key_col_count = self._NUM_KEY_COLS
        value_col_count = len(schema.fields) - key_col_count
        # Two arg-offset groups -- one for input data, one for initial state.
        # Both datasets share the schema, so each resolves to key=[0], values=[1..n].
        input_offsets = MockUDFFactory.make_grouped_arg_offsets(key_col_count, value_col_count)
        init_offsets = MockUDFFactory.make_grouped_arg_offsets(key_col_count, value_col_count)
        arg_offsets = input_offsets + init_offsets
        grouping_key_schema = StructType(schema.fields[:key_col_count])

        # Both wire columns (``inputData`` and ``initState``) use one struct
        # type because the two datasets share a schema; derive it from the
        # first data batch.
        first = data_batches[0]
        struct_type = pa.StructArray.from_arrays(
            [first.column(idx) for idx in range(first.num_columns)],
            names=first.schema.names,
        ).type

        # Initial-state batches first, then the data batches.
        nested_batches = [
            self._wrap_nested(batch, struct_type, is_init=True) for batch in init_batches
        ]
        nested_batches.extend(
            self._wrap_nested(batch, struct_type, is_init=False) for batch in data_batches
        )

        def write_udf(stream):
            return MockProtocolWriter.write_udf_payload(udf_func, ret_type, arg_offsets, stream)

        def write_data(stream):
            return MockProtocolWriter.write_data_payload(iter(nested_batches), stream)

        eval_conf = {
            "state_server_socket_port": str(_StubStateServer.get_port()),
            "grouping_key_schema": grouping_key_schema.json(),
        }
        MockProtocolWriter.write_worker_input(
            PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
            write_udf,
            write_data,
            buf,
            eval_conf=eval_conf,
        )


# Time benchmark over the init-state scenario grid.
class TransformWithStatePandasInitStateUDFTimeBench(
    _TransformWithStatePandasInitStateBenchMixin, _TimeBenchBase
):
    pass


# Peak-memory benchmark over the init-state scenario grid.
class TransformWithStatePandasInitStateUDFPeakmemBench(
    _TransformWithStatePandasInitStateBenchMixin, _PeakmemBenchBase
):
    pass