From 9a3bc1da1c109a32b27b6fdb67351c25e7b8a2ec Mon Sep 17 00:00:00 2001 From: Ramanujam Date: Fri, 22 May 2026 18:20:25 +0530 Subject: [PATCH 1/3] fix: cast binary-payload columns to BinaryType before mapInPandas The features and metadata columns can contain raw binary data stored as StringType. PyArrow fails UTF-8 validation on these during Arrow-to-Pandas serialization in mapInPandas, throwing ArrowException before _decode_batch code even runs. - Cast features and metadata to BinaryType before mapInPandas so Arrow serializes them as raw bytes without UTF-8 validation - Handle bytes/bytearray input in _decode_batch for features_data - Handle bytes/bytearray input in _extract_metadata_byte for metadata Co-Authored-By: Claude Opus 4.6 (1M context) --- .../inference_logging_client/__init__.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/py-sdk/inference_logging_client/inference_logging_client/__init__.py b/py-sdk/inference_logging_client/inference_logging_client/__init__.py index 7efa8f2d..115703c0 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/__init__.py +++ b/py-sdk/inference_logging_client/inference_logging_client/__init__.py @@ -237,6 +237,11 @@ def _extract_metadata_byte(metadata_data, json_module, base64_module) -> int: return 0 except (TypeError, ValueError): pass + if isinstance(metadata_data, (bytes, bytearray)): + try: + metadata_data = metadata_data.decode("utf-8") + except (UnicodeDecodeError, ValueError): + return 0 if isinstance(metadata_data, str): try: parsed = json_module.loads(metadata_data) @@ -425,7 +430,12 @@ def _decode_batch(iterator): feature_schema = get_feature_schema(mp_config_id, version, inference_host) except Exception: continue - if isinstance(features_data, str): + if isinstance(features_data, (bytes, bytearray)): + try: + features_list = json.loads(features_data.decode("utf-8")) + except (json.JSONDecodeError, ValueError, TypeError, UnicodeDecodeError): + continue + elif 
isinstance(features_data, str): try: features_list = json.loads(features_data) except (json.JSONDecodeError, ValueError, TypeError): @@ -522,6 +532,14 @@ def _decode_batch(iterator): out_pdf = pd.DataFrame(out_rows, columns=all_columns_ordered) yield out_pdf + # Cast binary-payload columns to BinaryType so Arrow serializes them as raw + # bytes instead of attempting UTF-8 string decoding (which fails on binary payloads) + from pyspark.sql.types import BinaryType as _BinaryType + from pyspark.sql import functions as F + for _col_name in (features_column, metadata_column): + if not isinstance(df.schema[_col_name].dataType, _BinaryType): + df = df.withColumn(_col_name, F.col(_col_name).cast(_BinaryType())) + n_partitions = num_partitions if num_partitions is not None else 10000 df_repart = df.repartition(n_partitions) From 7ae9f15cbb3deb4d1b5c963ce7b578090561f852 Mon Sep 17 00:00:00 2001 From: Ramanujam Date: Fri, 22 May 2026 19:43:58 +0530 Subject: [PATCH 2/3] feat: add v2 raw-proto wire format support and BinaryType fixes - Add _split_raw_proto_entities and _is_raw_proto_wire_format helpers to auto-detect and decode the v2 binary framing format [{}, ...] 
- Cast both features and metadata columns to BinaryType before mapInPandas to prevent ArrowException on non-UTF-8 binary payloads - Handle bytes input in _extract_metadata_byte from BinaryType cast - Stringify parent_entity values to prevent ArrowTypeError - Port v0.3.7 additions: decode_mplog_proto_dataframe, decode_mplog_proto_csv, _normalize_schema Co-Authored-By: Claude Opus 4.6 (1M context) --- .../inference_logging_client/__init__.py | 721 +++++++++++++++++- .../inference_logging_client/pyproject.toml | 2 +- 2 files changed, 695 insertions(+), 28 deletions(-) diff --git a/py-sdk/inference_logging_client/inference_logging_client/__init__.py b/py-sdk/inference_logging_client/inference_logging_client/__init__.py index 115703c0..942db7d1 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/__init__.py +++ b/py-sdk/inference_logging_client/inference_logging_client/__init__.py @@ -48,7 +48,7 @@ from .types import FORMAT_TYPE_MAP, DecodedMPLog, FeatureInfo, Format from .utils import format_dataframe_floats, get_format_name, unpack_metadata_byte -__version__ = "0.3.1" +__version__ = "0.3.8" # Maximum supported schema version (4 bits = 0-15) _MAX_SCHEMA_VERSION = 15 @@ -56,6 +56,8 @@ __all__ = [ "decode_mplog", "decode_mplog_dataframe", + "decode_mplog_proto_dataframe", + "decode_mplog_proto_csv", "get_mplog_metadata", "get_feature_schema", "clear_schema_cache", @@ -99,6 +101,58 @@ def _decompress_zstd(data: bytes) -> bytes: return data +def _split_raw_proto_entities(raw: bytes) -> list: + """Split v2 raw-proto wire format into per-entity byte chunks. + + Wire format: [{}, {}, ...] + Separator between entities is b'},{' or b'}, {'. 
+ """ + if len(raw) < 4: + return [raw] if raw else [] + + if raw[0:1] == b'[' and raw[1:2] == b'{': + inner = raw[2:] + if inner.endswith(b'}]'): + inner = inner[:-2] + elif inner.endswith(b'}'): + inner = inner[:-1] + + chunks = [] + start = 0 + i = 0 + while i < len(inner): + if inner[i:i + 1] == b'}': + rest = inner[i + 1:i + 4] + if rest.startswith(b', {') or rest.startswith(b',{'): + chunks.append(inner[start:i]) + skip = 3 if rest.startswith(b', ') else 2 + start = i + 1 + skip + i = start + continue + i += 1 + if start < len(inner): + chunks.append(inner[start:]) + return [c for c in chunks if c] + + if raw[0:1] == b'{': + inner = raw[1:] + if inner.endswith(b'}'): + inner = inner[:-1] + return [inner] if inner else [] + + return [raw] + + +def _is_raw_proto_wire_format(data) -> bool: + """Detect v2 raw-proto: starts with '[{' followed by non-'"' byte.""" + if isinstance(data, str): + return len(data) >= 3 and data[0] == '[' and data[1] == '{' and data[2] != '"' + if isinstance(data, (bytes, bytearray, memoryview)): + raw = bytes(data) + return len(raw) >= 3 and raw[0:1] == b'[' and raw[1:2] == b'{' and raw[2:3] != b'"' + return False + + def decode_mplog( log_data: bytes, model_proxy_id: str, @@ -237,6 +291,7 @@ def _extract_metadata_byte(metadata_data, json_module, base64_module) -> int: return 0 except (TypeError, ValueError): pass + # Handle bytes/bytearray from BinaryType cast if isinstance(metadata_data, (bytes, bytearray)): try: metadata_data = metadata_data.decode("utf-8") @@ -280,6 +335,10 @@ def decode_mplog_dataframe( """ Decode MPLog features from a Spark DataFrame with specific column structure. + Supports two wire formats (auto-detected per row): + - JSON envelope: [{"encoded_features": "base64..."}] (original format) + - v2 raw-proto: [{}, {}, ...] 
(binary framing) + Expected DataFrame columns: - prism_ingested_at, prism_extracted_at, created_at - entities, features, metadata @@ -374,7 +433,7 @@ def decode_mplog_dataframe( "hour", ] _reserved_columns = {"entity_id"} | {c for c in row_metadata_columns if c in df_columns} - + # Build full output schema: entity_id + metadata cols + (optionally restricted) feature names all_feature_names = set() for feat_list in schema_cache.values(): @@ -415,6 +474,12 @@ def _decode_batch(iterator): if features_data is None: continue metadata_data = _safe_get(row, metadata_column) + # Handle bytes from BinaryType cast + if isinstance(metadata_data, (bytes, bytearray)): + try: + metadata_data = metadata_data.decode("utf-8") + except (UnicodeDecodeError, ValueError): + continue metadata_byte = _extract_metadata_byte(metadata_data, json, base64) _, version, _ = unpack_metadata_byte(metadata_byte) if not (0 <= version <= _MAX_SCHEMA_VERSION): @@ -430,20 +495,8 @@ def _decode_batch(iterator): feature_schema = get_feature_schema(mp_config_id, version, inference_host) except Exception: continue - if isinstance(features_data, (bytes, bytearray)): - try: - features_list = json.loads(features_data.decode("utf-8")) - except (json.JSONDecodeError, ValueError, TypeError, UnicodeDecodeError): - continue - elif isinstance(features_data, str): - try: - features_list = json.loads(features_data) - except (json.JSONDecodeError, ValueError, TypeError): - continue - else: - features_list = features_data - if not isinstance(features_list, list): - continue + + # Parse entities entities_val = None if "entities" in df_columns: entities_raw = _safe_get(row, "entities") @@ -471,7 +524,63 @@ def _decode_batch(iterator): if isinstance(parent_val, list): parent_entity_val = parent_val[0] if len(parent_val) == 1 else str(parent_val) if len(parent_val) > 1 else None else: - parent_entity_val = parent_val + parent_entity_val = str(parent_val) + + # --- v2 raw-proto wire format: [{}, ...] 
--- + if _is_raw_proto_wire_format(features_data): + if isinstance(features_data, str): + features_data = features_data.encode("utf-8", errors="surrogateescape") + elif isinstance(features_data, memoryview): + features_data = bytes(features_data) + entity_chunks = _split_raw_proto_entities(features_data) + for i, chunk in enumerate(entity_chunks): + if decompress: + chunk = _decompress_zstd(chunk) + try: + decoded_features = decode_proto_features( + chunk, feature_schema, needed_columns=needed_columns + ) + except Exception: + continue + entity_id = str(entities_val[i]) if entities_val and i < len(entities_val) else f"entity_{i}" + result_row = {"entity_id": entity_id} + for k, v in decoded_features.items(): + if k in _reserved_columns: + continue + if v is None: + result_row[k] = None + elif isinstance(v, (list, tuple)): + result_row[k] = str(v) + elif isinstance(v, bytes): + result_row[k] = v.hex() + else: + result_row[k] = str(v) + for col in row_metadata_columns: + if col in df_columns: + result_row[col] = _safe_get(row, col) + if parent_entity_val is not None: + result_row["parent_entity"] = str(parent_entity_val) + for col in all_columns_ordered: + if col not in result_row: + result_row[col] = None + out_rows.append(result_row) + continue # next row + + # --- JSON envelope format: [{"encoded_features": "base64..."}] --- + if isinstance(features_data, (bytes, bytearray, memoryview)): + try: + features_data = bytes(features_data).decode("utf-8") + except UnicodeDecodeError: + continue + if isinstance(features_data, str): + try: + features_list = json.loads(features_data) + except (json.JSONDecodeError, ValueError, TypeError): + continue + else: + features_list = features_data + if not isinstance(features_list, list): + continue for i, feature_item in enumerate(features_list): if not isinstance(feature_item, dict): continue @@ -518,11 +627,9 @@ def _decode_batch(iterator): result_row[k] = str(v) for col in row_metadata_columns: if col in df_columns: - # Pass 
through as-is to preserve original types - # (LongType, TimestampType, etc.) result_row[col] = _safe_get(row, col) if parent_entity_val is not None: - result_row["parent_entity"] = parent_entity_val + result_row["parent_entity"] = str(parent_entity_val) # Fill missing schema columns with None for col in all_columns_ordered: if col not in result_row: @@ -532,15 +639,14 @@ def _decode_batch(iterator): out_pdf = pd.DataFrame(out_rows, columns=all_columns_ordered) yield out_pdf - # Cast binary-payload columns to BinaryType so Arrow serializes them as raw - # bytes instead of attempting UTF-8 string decoding (which fails on binary payloads) + n_partitions = num_partitions if num_partitions is not None else 10000 + # Cast binary-payload columns to BinaryType so pyarrow does not UTF-8 + # validate them at the Arrow->pandas boundary. + from pyspark.sql import functions as _F from pyspark.sql.types import BinaryType as _BinaryType - from pyspark.sql import functions as F for _col_name in (features_column, metadata_column): - if not isinstance(df.schema[_col_name].dataType, _BinaryType): - df = df.withColumn(_col_name, F.col(_col_name).cast(_BinaryType())) - - n_partitions = num_partitions if num_partitions is not None else 10000 + if not isinstance(input_field_map.get(_col_name), _BinaryType): + df = df.withColumn(_col_name, _F.col(_col_name).cast(_BinaryType())) df_repart = df.repartition(n_partitions) batch_limit = max_records_per_batch if max_records_per_batch is not None else 200 @@ -572,3 +678,564 @@ def _decode_batch(iterator): feature_cols = [c for c in result_columns if c not in metadata_cols] column_order = metadata_cols + feature_cols return result_df.select(column_order) + + +def _normalize_schema(schema) -> "list[FeatureInfo]": + """Accept either a list[FeatureInfo], a list of raw dicts, or the + inference-service JSON shape ``{"data": [...]}`` and return list[FeatureInfo]. 
+ + Raw dict items must carry ``feature_name`` and ``feature_type`` keys + (matching the inference service response). Order is preserved and used + to assign the ``index`` of each FeatureInfo, which is the proto field + position used by the decoder. + """ + if schema is None: + raise ValueError("schema must not be None") + + # Unwrap {"data": [...]} JSON shape + if isinstance(schema, dict): + if "data" not in schema: + raise ValueError("schema dict must contain a 'data' key") + items = schema["data"] + else: + items = schema + + if not isinstance(items, list) or not items: + raise ValueError("schema must be a non-empty list (or dict with non-empty 'data')") + + # Already FeatureInfo objects + if all(isinstance(it, FeatureInfo) for it in items): + return items + + normalized: list[FeatureInfo] = [] + for idx, item in enumerate(items): + if isinstance(item, FeatureInfo): + normalized.append(item) + continue + if not isinstance(item, dict): + raise ValueError( + f"schema item at index {idx} must be FeatureInfo or dict, got {type(item).__name__}" + ) + name = item.get("feature_name") or item.get("name") + feature_type = item.get("feature_type") + if not name or not feature_type: + raise ValueError( + f"schema item at index {idx} missing 'feature_name'/'name' or 'feature_type'" + ) + normalized.append(FeatureInfo(name=name, feature_type=feature_type, index=idx)) + return normalized + + +def decode_mplog_proto_dataframe( + df: "SparkDataFrame", + spark: "SparkSession", + schema, + decompress: bool = True, + features_column: str = "features", + mp_config_id_column: str = "mp_config_id", + num_partitions: Optional[int] = None, + max_records_per_batch: Optional[int] = None, + needed_columns: Optional[Collection[str]] = None, +) -> "SparkDataFrame": + """ + Decode MPLog features from a Spark DataFrame using a caller-supplied schema. + + Format is always PROTO. No schema fetch is performed and no inference service + is contacted. 
The caller is responsible for passing the correct schema for the + encoded payloads in the DataFrame; all rows are decoded against the same schema. + + Supports both JSON envelope and v2 raw-proto wire formats (auto-detected). + + Expected DataFrame columns: + - features (encoded payloads; JSON-array-of-base64 strings or raw-proto framed bytes) + - mp_config_id + - optional: entities, parent_entity + - optional row-metadata: prism_ingested_at, prism_extracted_at, created_at, + tracking_id, user_id, year, month, day, hour + + Args: + df: Input Spark DataFrame. + spark: The SparkSession to use for creating the result DataFrame. + schema: Schema applied to all rows. Accepted shapes: + - list[FeatureInfo] + - list[dict] with keys 'feature_name' (or 'name') and 'feature_type' + - dict {"data": [...]} matching the inference service JSON response + Order is used to assign the proto field index; do not reorder. + decompress: Whether to attempt zstd decompression on each encoded payload. + features_column: Name of the column containing encoded features (default: "features"). + mp_config_id_column: Name of the column containing model proxy config ID + (default: "mp_config_id"). Pass-through column; not used to look up schema. + num_partitions: Number of partitions for distributed decode. Default 10000. + max_records_per_batch: Max rows per Arrow batch in mapInPandas. Default 50. + needed_columns: Optional set or list of feature names to include. If provided, + only these columns are decoded and returned. + + Returns: + Spark DataFrame with entity_id as first column, followed by available row-metadata + columns, followed by feature columns. + + Example: + >>> from pyspark.sql import SparkSession + >>> from inference_logging_client import ( + ... decode_mplog_proto_dataframe, get_feature_schema, + ... 
) + >>> spark = SparkSession.builder.appName("decode").getOrCreate() + >>> df = spark.read.parquet("logs.parquet") + >>> schema = get_feature_schema("my-model", 1) + >>> decoded_df = decode_mplog_proto_dataframe(df, spark, schema=schema) + >>> decoded_df.show() + """ + import base64 + import json + + schema = _normalize_schema(schema) + + # Check if DataFrame is empty (avoid full count: use limit(1)) + if df.limit(1).count() == 0: + from pyspark.sql.types import StructType + return spark.createDataFrame([], StructType([])) + + # Validate required columns + df_columns = df.columns + required_columns = [features_column, mp_config_id_column] + missing_columns = [c for c in required_columns if c not in df_columns] + if missing_columns: + raise ValueError(f"Missing required columns: {missing_columns}") + + row_metadata_columns = [ + "prism_ingested_at", + "prism_extracted_at", + "created_at", + "mp_config_id", + "parent_entity", + "tracking_id", + "user_id", + "year", + "month", + "day", + "hour", + ] + _reserved_columns = {"entity_id"} | {c for c in row_metadata_columns if c in df_columns} + + # Build output schema: entity_id + available metadata cols + feature names + all_feature_names = {f.name for f in schema} + if needed_columns is not None: + all_feature_names = all_feature_names & set(needed_columns) + metadata_cols_in_schema = [c for c in row_metadata_columns if c in df_columns] + + from pyspark.sql.types import StringType, StructField, StructType + input_field_map = {field.name: field.dataType for field in df.schema.fields} + schema_fields = [StructField("entity_id", StringType(), True)] + for c in metadata_cols_in_schema: + original_type = input_field_map.get(c, StringType()) + schema_fields.append(StructField(c, original_type, True)) + for c in sorted(all_feature_names): + schema_fields.append(StructField(c, StringType(), True)) + full_schema = StructType(schema_fields) + all_columns_ordered = ["entity_id"] + metadata_cols_in_schema + 
sorted(all_feature_names) + + # Project to only the columns we actually need on workers + projected_cols = [ + c for c in ( + [features_column, mp_config_id_column, "entities"] + row_metadata_columns + ) + if c in df_columns + ] + seen = set() + projected_cols = [c for c in projected_cols if not (c in seen or seen.add(c))] + + # Cast features to BinaryType for safe Arrow serialization + from pyspark.sql import functions as _F + from pyspark.sql.types import BinaryType as _BinaryType + if not isinstance(input_field_map.get(features_column), _BinaryType): + df = df.withColumn(features_column, _F.col(features_column).cast(_BinaryType())) + + df_projected = df.select(*projected_cols) + + # Capture for closure + feature_schema = schema + + # --- Hot-path precomputation (runs once per call, used by every worker + # invocation of _decode_batch). Avoids redoing this work per row/entity. --- + has_entities_col = "entities" in df_columns + has_parent_entity_col = "parent_entity" in df_columns + metadata_cols_present = [c for c in row_metadata_columns if c in df_columns] + # Pre-built template dict avoids the "fill missing columns with None" + # loop per entity. We copy it per output row and overwrite the cells + # we actually have values for. + row_template = {c: None for c in all_columns_ordered} + + def _decode_batch(iterator): + # Imports inside the worker function — pyspark needs the function to + # be self-contained for cloudpickle, and free-variable callable refs + # tend to break pickling on some pyspark builds. + import base64 as _base64 + import json as _json + import pandas as pd + + _b64decode = _base64.b64decode + _json_loads = _json.loads + _decode_proto = decode_proto_features + _decompress = _decompress_zstd + + for pdf in iterator: + # Single conversion to list-of-dicts is dramatically faster than + # pandas.iterrows() for wide+long DataFrames. 
iterrows materializes + # a Series per row with per-cell type lookups; to_dict("records") + # walks the underlying numpy arrays once. + records = pdf.to_dict(orient="records") + out_rows = [] + out_rows_append = out_rows.append # local-bind for speed + + for row in records: + features_data = row.get(features_column) + if not features_data: + continue + + entities_val = None + if has_entities_col: + entities_raw = row.get("entities") + if entities_raw: + if isinstance(entities_raw, str): + try: + parsed = _json_loads(entities_raw) + entities_val = parsed if isinstance(parsed, list) else [entities_raw] + except (ValueError, TypeError): + entities_val = [entities_raw] + elif isinstance(entities_raw, list): + entities_val = entities_raw + else: + entities_val = [entities_raw] + + parent_entity_val = None + if has_parent_entity_col: + parent_val = row.get("parent_entity") + if parent_val: + if isinstance(parent_val, str): + try: + parent_val = _json_loads(parent_val) + except (ValueError, TypeError): + parent_val = [parent_val] + if isinstance(parent_val, list): + n_parents = len(parent_val) + if n_parents == 1: + parent_entity_val = parent_val[0] + elif n_parents > 1: + parent_entity_val = str(parent_val) + else: + parent_entity_val = str(parent_val) + + # Precompute the row-metadata snapshot once per input row — + # every entity expansion below shares the same values. 
+ base_metadata = {c: row.get(c) for c in metadata_cols_present} + if parent_entity_val is not None and "parent_entity" in metadata_cols_present: + base_metadata["parent_entity"] = str(parent_entity_val) + + entities_len = len(entities_val) if entities_val else 0 + + # --- v2 raw-proto wire format --- + if _is_raw_proto_wire_format(features_data): + if isinstance(features_data, str): + features_data = features_data.encode("utf-8", errors="surrogateescape") + elif isinstance(features_data, (memoryview,)): + features_data = bytes(features_data) + entity_chunks = _split_raw_proto_entities(features_data) + for i, chunk in enumerate(entity_chunks): + if decompress: + chunk = _decompress(chunk) + try: + decoded_features = _decode_proto( + chunk, feature_schema, needed_columns=needed_columns + ) + except Exception: + continue + entity_id = ( + str(entities_val[i]) if i < entities_len else f"entity_{i}" + ) + result_row = row_template.copy() + result_row["entity_id"] = entity_id + if base_metadata: + result_row.update(base_metadata) + for k, v in decoded_features.items(): + if k in _reserved_columns: + continue + if v is None: + result_row[k] = None + elif type(v) is str: + result_row[k] = v + elif isinstance(v, (list, tuple)): + result_row[k] = str(v) + elif isinstance(v, bytes): + result_row[k] = v.hex() + else: + result_row[k] = str(v) + out_rows_append(result_row) + continue + + # --- JSON envelope format (str, or bytes after the BinaryType cast; json.loads accepts both) --- + if not isinstance(features_data, (str, bytes, bytearray)): + continue + try: + features_list = _json_loads(features_data) + except (ValueError, TypeError): + continue + if not isinstance(features_list, list): + continue + + for i, feature_item in enumerate(features_list): + if not isinstance(feature_item, dict): + continue + encoded_features_b64 = feature_item.get("encoded_features") + if not encoded_features_b64: + continue + try: + encoded_bytes = _b64decode(encoded_features_b64) + except (ValueError, TypeError): + continue + if not encoded_bytes: + continue + + if decompress: +
try: + working_data = _decompress(encoded_bytes) + except Exception: + continue + else: + working_data = encoded_bytes + + try: + decoded_features = _decode_proto( + working_data, feature_schema, needed_columns=needed_columns + ) + except Exception: + continue + + entity_id = ( + str(entities_val[i]) + if i < entities_len + else f"entity_{i}" + ) + + # Copy the prebuilt template instead of building a fresh + # dict and then filling all 322 missing keys. + result_row = row_template.copy() + result_row["entity_id"] = entity_id + if base_metadata: + result_row.update(base_metadata) + + # Stringify decoded values for output (output schema is + # all StringType for feature cols). Skip reserved cols. + for k, v in decoded_features.items(): + if k in _reserved_columns: + continue + if v is None: + result_row[k] = None + elif type(v) is str: + result_row[k] = v + elif isinstance(v, (list, tuple)): + result_row[k] = str(v) + elif isinstance(v, bytes): + result_row[k] = v.hex() + else: + result_row[k] = str(v) + + out_rows_append(result_row) + + if out_rows: + yield pd.DataFrame(out_rows, columns=all_columns_ordered) + + n_partitions = num_partitions if num_partitions is not None else 10000 + df_repart = df_projected.repartition(n_partitions) + + batch_limit = max_records_per_batch if max_records_per_batch is not None else 50 + prev_max_records = spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch") + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", str(batch_limit)) + try: + result_df = df_repart.mapInPandas(_decode_batch, full_schema) + finally: + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", prev_max_records or "10000") + + # Reorder columns: entity_id first, then metadata, then features + result_columns = result_df.columns + metadata_cols = ["entity_id"] + for col in row_metadata_columns: + if col in result_columns: + metadata_cols.append(col) + feature_cols = [c for c in result_columns if c not in metadata_cols] + column_order = 
metadata_cols + feature_cols + return result_df.select(column_order) + + +def decode_mplog_proto_csv( + input_csv: str, + output_csv: str, + schema, + decompress: bool = True, + features_column: str = "features", + mp_config_id_column: str = "mp_config_id", + needed_columns: Optional[Collection[str]] = None, +) -> int: + """ + Decode an MPLog CSV file directly to another CSV, without Spark. + + Reads the input CSV row-by-row, decodes each row's encoded entities using + the caller-supplied PROTO schema, and writes one decoded row per entity to + output_csv. Pure-Python; uses only csv/json/base64 + decode_proto_features. + + Expected input columns: features, mp_config_id, optionally entities, + parent_entity, and the row-metadata columns (prism_ingested_at, etc). + + Args: + input_csv: Path to the input CSV. + output_csv: Path where the decoded CSV will be written. + schema: Same shapes accepted by decode_mplog_proto_dataframe: + list[FeatureInfo], list[dict], or {"data": [...]}. + decompress: Attempt zstd decompression per encoded payload. + features_column: Column with the encoded features JSON. + mp_config_id_column: Pass-through column name. + needed_columns: Optional set of feature names to keep. + + Returns: + Number of decoded rows written. + """ + import base64 + import csv as _csv + import json + import sys as _sys + + # MPLog features cells can be multi-MB; lift the csv field-size cap. 
+ try: + _csv.field_size_limit(_sys.maxsize) + except OverflowError: + _csv.field_size_limit(2**31 - 1) + + schema_list = _normalize_schema(schema) + needed_set = set(needed_columns) if needed_columns is not None else None + + feature_names = [f.name for f in schema_list] + if needed_set is not None: + feature_names = [n for n in feature_names if n in needed_set] + + row_metadata_columns = [ + "prism_ingested_at", + "prism_extracted_at", + "created_at", + "mp_config_id", + "parent_entity", + "tracking_id", + "user_id", + "year", + "month", + "day", + "hour", + ] + + with open(input_csv, "r", newline="", encoding="utf-8") as f_in: + reader = _csv.DictReader(f_in) + if reader.fieldnames is None: + raise ValueError(f"Input CSV {input_csv} has no header row") + input_columns = set(reader.fieldnames) + + if features_column not in input_columns: + raise ValueError(f"Missing required column: {features_column}") + + present_metadata_cols = [c for c in row_metadata_columns if c in input_columns] + out_columns = ["entity_id"] + present_metadata_cols + sorted(feature_names) + + n_written = 0 + with open(output_csv, "w", newline="", encoding="utf-8") as f_out: + writer = _csv.DictWriter(f_out, fieldnames=out_columns, extrasaction="ignore") + writer.writeheader() + + for row in reader: + features_data = row.get(features_column) + if not features_data: + continue + try: + features_list = json.loads(features_data) + except (json.JSONDecodeError, ValueError, TypeError): + continue + if not isinstance(features_list, list): + continue + + entities_val = None + if "entities" in input_columns: + entities_raw = row.get("entities") + if entities_raw: + try: + parsed = json.loads(entities_raw) + entities_val = parsed if isinstance(parsed, list) else [parsed] + except (json.JSONDecodeError, ValueError): + entities_val = [entities_raw] + + parent_entity_val = None + if "parent_entity" in input_columns: + parent_raw = row.get("parent_entity") + if parent_raw: + try: + parsed = 
json.loads(parent_raw) + if isinstance(parsed, list): + parent_entity_val = ( + parsed[0] if len(parsed) == 1 + else str(parsed) if len(parsed) > 1 + else None + ) + else: + parent_entity_val = parsed + except (json.JSONDecodeError, ValueError): + parent_entity_val = parent_raw + + base_metadata = {c: row.get(c) for c in present_metadata_cols} + + for i, feature_item in enumerate(features_list): + if not isinstance(feature_item, dict): + continue + encoded_b64 = feature_item.get("encoded_features", "") + if not encoded_b64: + continue + try: + encoded_bytes = base64.b64decode(encoded_b64) + except (ValueError, TypeError): + continue + if not encoded_bytes: + continue + + working_data = encoded_bytes + if decompress: + try: + working_data = _decompress_zstd(encoded_bytes) + except Exception: + continue + + try: + decoded = decode_proto_features( + working_data, schema_list, needed_columns=needed_set + ) + except Exception: + continue + + entity_id = ( + str(entities_val[i]) + if entities_val and i < len(entities_val) + else f"entity_{i}" + ) + + out_row = {"entity_id": entity_id} + out_row.update(base_metadata) + if parent_entity_val is not None and "parent_entity" in present_metadata_cols: + out_row["parent_entity"] = parent_entity_val + for k, v in decoded.items(): + if needed_set is not None and k not in needed_set: + continue + if v is None: + out_row[k] = "" + elif isinstance(v, (list, tuple)): + out_row[k] = str(v) + elif isinstance(v, bytes): + out_row[k] = v.hex() + else: + out_row[k] = v + writer.writerow(out_row) + n_written += 1 + + return n_written diff --git a/py-sdk/inference_logging_client/pyproject.toml b/py-sdk/inference_logging_client/pyproject.toml index 13366434..583fed90 100644 --- a/py-sdk/inference_logging_client/pyproject.toml +++ b/py-sdk/inference_logging_client/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "inference-logging-client" -version = "0.3.1" +version = "0.3.8" description = "Decode MPLog 
feature logs from proto, arrow, or parquet format" readme = "readme.md" requires-python = ">=3.8" From e2651883bd38987383413bdd5827dec7f3f38372 Mon Sep 17 00:00:00 2001 From: "jaya.kommuru" Date: Mon, 25 May 2026 15:25:47 +0530 Subject: [PATCH 3/3] chore(inference-logging-client): loosen pyspark pin and bump to 0.3.9 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Loosen the pyspark dependency from `==3.3.0` to `>=3.3,<4` so the client co-installs cleanly with packages targeting newer PySpark minors (e.g. meeupp-spark-lib-python 1.2.2 → pyspark~=3.5.3), which previously failed pip resolution. PySpark 3.x is API-stable for the surface this client uses (Spark SQL DataFrame / mapInPandas / Arrow types). Bump to 0.3.9 because 0.3.8 is already on PyPI. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../inference_logging_client/__init__.py | 2 +- py-sdk/inference_logging_client/pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/py-sdk/inference_logging_client/inference_logging_client/__init__.py b/py-sdk/inference_logging_client/inference_logging_client/__init__.py index 942db7d1..2bac3f3d 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/__init__.py +++ b/py-sdk/inference_logging_client/inference_logging_client/__init__.py @@ -48,7 +48,7 @@ from .types import FORMAT_TYPE_MAP, DecodedMPLog, FeatureInfo, Format from .utils import format_dataframe_floats, get_format_name, unpack_metadata_byte -__version__ = "0.3.8" +__version__ = "0.3.9" # Maximum supported schema version (4 bits = 0-15) _MAX_SCHEMA_VERSION = 15 diff --git a/py-sdk/inference_logging_client/pyproject.toml b/py-sdk/inference_logging_client/pyproject.toml index 583fed90..a989dc73 100644 --- a/py-sdk/inference_logging_client/pyproject.toml +++ b/py-sdk/inference_logging_client/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "inference-logging-client" -version = "0.3.8" 
+version = "0.3.9" description = "Decode MPLog feature logs from proto, arrow, or parquet format" readme = "readme.md" requires-python = ">=3.8" @@ -28,7 +28,7 @@ classifiers = [ ] dependencies = [ - "pyspark==3.3.0", + "pyspark>=3.3,<4", "pyarrow>=5.0.0", "zstandard>=0.15.0", ]