From dcb60f240c6ff6ff81a46ec344fcaff12b9f2389 Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Wed, 29 Apr 2026 17:16:05 +0530 Subject: [PATCH 1/7] Add C++ DetectParamTypes + SQLExecuteFast pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move parameter type detection from Python into C++ using raw CPython type checks (PyLong_CheckExact, PyFloat_CheckExact, etc.). Merge the DetectParamTypes → BindParameters → SQLExecute pipeline into a single DDBCSQLExecuteFast call so ParamInfo never crosses the pybind11 boundary. - DetectParamTypes: handles int (range-detected), float, bool, str (unicode + geometry sniffing), bytes, datetime/date/time, Decimal (MONEY range + generic numeric), UUID, None, with fallback to string - SQLExecuteFast_wrap: single pipeline with GIL release, always uses SQLPrepare for parameterized queries - cursor.py: fast path routing when no setinputsizes overrides present; old DDBCSQLExecute path preserved for setinputsizes callers - Named constants: MAX_INLINE_CHAR, MAX_INLINE_BINARY, MAX_NUMERIC_PRECISION, MONEY/SMALLMONEY ranges, PARAM_C_TYPE_TEXT platform macro --- mssql_python/cursor.py | 82 +++-- mssql_python/pybind/ddbc_bindings.cpp | 446 ++++++++++++++++++++++++++ 2 files changed, 495 insertions(+), 33 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 05324875..783d0c69 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -1452,11 +1452,6 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state # Getting encoding setting encoding_settings = self._get_encoding_settings() - # Apply timeout if set (non-zero) - logger.debug("execute: Creating parameter type list") - param_info = ddbc_bindings.ParamInfo - parameters_type = [] - # Validate that inputsizes matches parameter count if both are present if parameters and self._inputsizes: if len(self._inputsizes) != len(parameters): @@ -1468,11 +1463,6 @@ def execute( # pylint: 
disable=too-many-locals,too-many-branches,too-many-state Warning, ) - if parameters: - for i, param in enumerate(parameters): - paraminfo = self._create_parameter_types_list(param, param_info, parameters, i) - parameters_type.append(paraminfo) - # Prepare caching: skip SQLPrepare when re-executing the same SQL # with parameters. The HSTMT is reused via _soft_reset_cursor, so the # server-side plan from the previous SQLPrepare is still valid. @@ -1481,30 +1471,56 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state self.is_stmt_prepared = [False] effective_use_prepare = use_prepare and not same_sql - if logger.isEnabledFor(logging.DEBUG): - for i, param in enumerate(parameters): - logger.debug( - """Parameter number: %s, Parameter: %s, - Param Python Type: %s, ParamInfo: %s, %s, %s, %s, %s""", - i + 1, - param, - str(type(param)), - parameters_type[i].paramSQLType, - parameters_type[i].paramCType, - parameters_type[i].columnSize, - parameters_type[i].decimalDigits, - parameters_type[i].inputOutputType, - ) - - ret = ddbc_bindings.DDBCSQLExecute( - self.hstmt, - operation, - parameters, - parameters_type, - self.is_stmt_prepared, - effective_use_prepare, - encoding_settings, + # Fast path: when no inputsizes override, do type detection + bind + execute + # entirely in C++. ParamInfo never crosses the pybind11 boundary. 
+ use_fast_path = parameters and not ( + self._inputsizes and any(s is not None for s in self._inputsizes) ) + + if use_fast_path: + ret = ddbc_bindings.DDBCSQLExecuteFast( + self.hstmt, + operation, + parameters, + self.is_stmt_prepared, + effective_use_prepare, + encoding_settings, + ) + else: + # Slow path: Python-side type detection (used when setinputsizes overrides are present) + parameters_type = [] + if parameters: + param_info = ddbc_bindings.ParamInfo + for i, param in enumerate(parameters): + paraminfo = self._create_parameter_types_list( + param, param_info, parameters, i + ) + parameters_type.append(paraminfo) + + if logger.isEnabledFor(logging.DEBUG): + for i, param in enumerate(parameters): + logger.debug( + """Parameter number: %s, Parameter: %s, + Param Python Type: %s, ParamInfo: %s, %s, %s, %s, %s""", + i + 1, + param, + str(type(param)), + parameters_type[i].paramSQLType, + parameters_type[i].paramCType, + parameters_type[i].columnSize, + parameters_type[i].decimalDigits, + parameters_type[i].inputOutputType, + ) + + ret = ddbc_bindings.DDBCSQLExecute( + self.hstmt, + operation, + parameters, + parameters_type, + self.is_stmt_prepared, + effective_use_prepare, + encoding_settings, + ) # Check return code try: diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 9d007653..0fc59701 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -163,6 +163,7 @@ struct ParamInfo { SQLLEN strLenOrInd = 0; // Required for DAE bool isDAE = false; // Indicates if we need to stream py::object dataPtr; + Py_ssize_t utf16Len = 0; // UTF-16 code unit count for string params }; #ifdef __GNUC__ #pragma GCC diagnostic pop @@ -439,6 +440,386 @@ std::string DescribeChar(unsigned char ch) { } } +// --------------------------------------------------------------------------- +// Constants for DetectParamTypes +// --------------------------------------------------------------------------- 
+ +// Strings longer than this use data-at-execution (DAE) streaming +static constexpr int MAX_INLINE_CHAR = 4000; + +// Binary data longer than this uses DAE streaming (SQL Server max for non-MAX types) +static constexpr int MAX_INLINE_BINARY = 8000; + +// SQL Server maximum numeric precision +static constexpr int MAX_NUMERIC_PRECISION = 38; + +// MONEY range: -922,337,203,685,477.5808 to 922,337,203,685,477.5807 +static constexpr double MONEY_MIN = -922337203685477.5808; +static constexpr double MONEY_MAX = 922337203685477.5807; + +// SMALLMONEY range: -214,748.3648 to 214,748.3647 +static constexpr double SMALLMONEY_MIN = -214748.3648; +static constexpr double SMALLMONEY_MAX = 214748.3647; + +// Platform-specific text C type: unixODBC requires all text as wide chars +#if defined(__APPLE__) || defined(__linux__) +static constexpr SQLSMALLINT PARAM_C_TYPE_TEXT = SQL_C_WCHAR; +#else +static constexpr SQLSMALLINT PARAM_C_TYPE_TEXT = SQL_C_CHAR; +#endif + +// Forward declare NumericData helper used by decimal path +static py::object build_numeric_data(const py::object& decimal_param); + +// --------------------------------------------------------------------------- +// DetectParamTypes — C++ type detection for the execute() fast path. +// +// Replaces the Python-side _create_parameter_types_list() loop by doing type +// detection entirely in C++ using raw CPython type checks. +// +// ORDERING MATTERS: +// - bool before int (bool is a subclass of int in Python) +// - datetime before date (datetime is a subclass of date) +// +// Some types mutate the params list in-place via PyList_SET_ITEM: +// - time → normalized to "HH:MM:SS.ffffff" string +// - Decimal in MONEY range → formatted via __format__("f") +// - Decimal (generic) → converted to NumericData struct +// - UUID → replaced with bytes_le +// This is safe because execute() already copies the caller's param list +// (via list(actual_params)) before reaching this function. 
+// --------------------------------------------------------------------------- +std::vector DetectParamTypes(py::list& params) { + PythonObjectCache::initialize(); + + const Py_ssize_t n = py::len(params); + std::vector infos(n); + + PyObject* decimal_type = PythonObjectCache::get_decimal_class().ptr(); + PyObject* uuid_type = PythonObjectCache::get_uuid_class().ptr(); + PyObject* datetime_type = PythonObjectCache::get_datetime_class().ptr(); + PyObject* date_type = PythonObjectCache::get_date_class().ptr(); + PyObject* time_type = PythonObjectCache::get_time_class().ptr(); + + for (Py_ssize_t i = 0; i < n; ++i) { + ParamInfo& info = infos[i]; + info.inputOutputType = SQL_PARAM_INPUT; + info.isDAE = false; + + PyObject* obj = PyList_GET_ITEM(params.ptr(), i); + + // --- None --- + if (obj == Py_None) { + info.paramSQLType = SQL_UNKNOWN_TYPE; + info.paramCType = SQL_C_DEFAULT; + info.columnSize = 1; + info.decimalDigits = 0; + continue; + } + + // --- bool (must check before int, since bool is subclass of int) --- + if (PyBool_Check(obj)) { + info.paramSQLType = SQL_BIT; + info.paramCType = SQL_C_BIT; + info.columnSize = 1; + info.decimalDigits = 0; + continue; + } + + // --- int --- + if (PyLong_CheckExact(obj)) { + int overflow = 0; + int64_t val = PyLong_AsLongLongAndOverflow(obj, &overflow); + if (overflow == 0 && !PyErr_Occurred()) { + if (val >= 0 && val <= 255) { + info.paramSQLType = SQL_TINYINT; + info.paramCType = SQL_C_TINYINT; + info.columnSize = 3; + } else if (val >= -32768 && val <= 32767) { + info.paramSQLType = SQL_SMALLINT; + info.paramCType = SQL_C_SHORT; + info.columnSize = 5; + } else if (val >= -2147483648LL && val <= 2147483647LL) { + info.paramSQLType = SQL_INTEGER; + info.paramCType = SQL_C_LONG; + info.columnSize = 10; + } else { + info.paramSQLType = SQL_BIGINT; + info.paramCType = SQL_C_SBIGINT; + info.columnSize = 19; + } + } else { + PyErr_Clear(); + info.paramSQLType = SQL_BIGINT; + info.paramCType = SQL_C_SBIGINT; + info.columnSize = 
19; + } + info.decimalDigits = 0; + continue; + } + + // --- float --- + if (PyFloat_CheckExact(obj)) { + info.paramSQLType = SQL_DOUBLE; + info.paramCType = SQL_C_DOUBLE; + info.columnSize = 15; + info.decimalDigits = 0; + continue; + } + + // --- str --- + if (PyUnicode_CheckExact(obj)) { + Py_ssize_t length = PyUnicode_GET_LENGTH(obj); + unsigned int kind = PyUnicode_KIND(obj); + + Py_ssize_t utf16_len; + if (kind <= PyUnicode_2BYTE_KIND) { + utf16_len = length; + } else { + utf16_len = 0; + const Py_UCS4* data = PyUnicode_4BYTE_DATA(obj); + for (Py_ssize_t j = 0; j < length; ++j) { + utf16_len += (data[j] > 0xFFFF) ? 2 : 1; + } + } + + bool is_unicode = (kind > PyUnicode_1BYTE_KIND) || + (PyUnicode_IS_COMPACT_ASCII(obj) == 0 && kind == PyUnicode_1BYTE_KIND + && PyUnicode_MAX_CHAR_VALUE(obj) > 127); + + if (utf16_len > MAX_INLINE_CHAR) { + info.isDAE = true; + info.columnSize = 0; + info.dataPtr = py::reinterpret_borrow(py::handle(obj)); + } else { + info.columnSize = is_unicode ? utf16_len : length; + info.utf16Len = utf16_len; + } + info.paramSQLType = is_unicode ? SQL_WVARCHAR : SQL_VARCHAR; + info.paramCType = is_unicode ? SQL_C_WCHAR : PARAM_C_TYPE_TEXT; + info.decimalDigits = 0; + + // Check geometry prefixes + if (length >= 5 && kind == PyUnicode_1BYTE_KIND) { + const char* ascii = (const char*)PyUnicode_1BYTE_DATA(obj); + if (strncmp(ascii, "POINT", 5) == 0 || + (length >= 10 && strncmp(ascii, "LINESTRING", 10) == 0) || + (length >= 7 && strncmp(ascii, "POLYGON", 7) == 0)) { + info.paramSQLType = SQL_WVARCHAR; + info.paramCType = SQL_C_WCHAR; + info.columnSize = length; + } + } + continue; + } + + // --- bytes / bytearray --- + if (PyBytes_CheckExact(obj) || PyByteArray_CheckExact(obj)) { + Py_ssize_t length = PyBytes_CheckExact(obj) ? 
PyBytes_GET_SIZE(obj) + : PyByteArray_GET_SIZE(obj); + info.paramSQLType = SQL_VARBINARY; + info.paramCType = SQL_C_BINARY; + info.decimalDigits = 0; + if (length > MAX_INLINE_BINARY) { + info.isDAE = true; + info.columnSize = 0; + info.dataPtr = py::reinterpret_borrow(py::handle(obj)); + } else { + info.columnSize = std::max(length, 1); + } + continue; + } + + // --- datetime (must check before date, since datetime is subclass of date) --- + if (PyObject_IsInstance(obj, datetime_type)) { + py::handle h(obj); + py::object tzinfo = h.attr("tzinfo"); + if (!tzinfo.is_none()) { + info.paramSQLType = SQL_SS_TIMESTAMPOFFSET; + info.paramCType = SQL_C_SS_TIMESTAMPOFFSET; + info.columnSize = 34; + info.decimalDigits = 7; + } else { + info.paramSQLType = SQL_TYPE_TIMESTAMP; + info.paramCType = SQL_C_TYPE_TIMESTAMP; + info.columnSize = 26; + info.decimalDigits = 6; + } + continue; + } + + // --- date --- + if (PyObject_IsInstance(obj, date_type)) { + info.paramSQLType = SQL_TYPE_DATE; + info.paramCType = SQL_C_TYPE_DATE; + info.columnSize = 10; + info.decimalDigits = 0; + continue; + } + + // --- time (normalized to string for binding) --- + if (PyObject_IsInstance(obj, time_type)) { + info.paramSQLType = SQL_TYPE_TIME; + info.paramCType = PARAM_C_TYPE_TEXT; + info.columnSize = 16; + info.decimalDigits = 6; + py::handle h(obj); + int hour = h.attr("hour").cast(); + int minute = h.attr("minute").cast(); + int second = h.attr("second").cast(); + int microsecond = h.attr("microsecond").cast(); + char buf[32]; + if (microsecond > 0) { + snprintf(buf, sizeof(buf), "%02d:%02d:%02d.%06d", hour, minute, second, microsecond); + } else { + snprintf(buf, sizeof(buf), "%02d:%02d:%02d", hour, minute, second); + } + py::str time_str(buf); + Py_ssize_t time_len = py::len(time_str); + info.columnSize = std::max(info.columnSize, time_len); + info.utf16Len = time_len; + PyList_SET_ITEM(params.ptr(), i, time_str.release().ptr()); + continue; + } + + // --- Decimal --- + if 
(PyObject_IsInstance(obj, decimal_type)) { + py::handle h(obj); + py::object as_tuple = h.attr("as_tuple")(); + py::object exponent_obj = as_tuple.attr("exponent"); + + if (py::isinstance(exponent_obj)) { + info.paramSQLType = SQL_NUMERIC; + info.paramCType = SQL_C_NUMERIC; + info.columnSize = MAX_NUMERIC_PRECISION; + info.decimalDigits = 0; + py::object numeric_data = build_numeric_data(py::reinterpret_borrow(h)); + PyList_SET_ITEM(params.ptr(), i, numeric_data.release().ptr()); + continue; + } + + py::tuple digits = as_tuple.attr("digits").cast(); + int num_digits = static_cast(py::len(digits)); + int exponent = exponent_obj.cast(); + int precision; + if (exponent >= 0) + precision = num_digits + exponent; + else if ((-exponent) <= num_digits) + precision = num_digits; + else + precision = -exponent; + + if (precision > MAX_NUMERIC_PRECISION) { + throw py::value_error( + "Precision of the numeric value is too high. " + "The maximum precision supported by SQL Server is " + + std::to_string(MAX_NUMERIC_PRECISION) + ", but got " + + std::to_string(precision) + "."); + } + + // SMALLMONEY/MONEY range — bind as formatted VARCHAR string + // to match SQL Server's fixed-point money semantics. 
+ double dval = h.attr("__float__")().cast(); + if (dval >= SMALLMONEY_MIN && dval <= MONEY_MAX) { + py::str formatted = h.attr("__format__")(py::str("f")); + info.paramSQLType = SQL_VARCHAR; + info.paramCType = PARAM_C_TYPE_TEXT; + Py_ssize_t fmtLen = py::len(formatted); + info.columnSize = fmtLen; + info.utf16Len = fmtLen; + info.decimalDigits = 0; + PyList_SET_ITEM(params.ptr(), i, formatted.release().ptr()); + continue; + } + + // Generic numeric binding via SQL_NUMERIC_STRUCT + info.paramSQLType = SQL_NUMERIC; + info.paramCType = SQL_C_NUMERIC; + py::object numeric_data = build_numeric_data(py::reinterpret_borrow(h)); + NumericData nd = numeric_data.cast(); + info.columnSize = nd.precision; + info.decimalDigits = nd.scale; + PyList_SET_ITEM(params.ptr(), i, numeric_data.release().ptr()); + continue; + } + + // --- UUID --- + if (PyObject_IsInstance(obj, uuid_type)) { + py::handle h(obj); + py::bytes bytes_le = h.attr("bytes_le"); + info.paramSQLType = SQL_GUID; + info.paramCType = SQL_C_GUID; + info.columnSize = 16; + info.decimalDigits = 0; + PyList_SET_ITEM(params.ptr(), i, bytes_le.release().ptr()); + continue; + } + + // --- Fallback: convert to string (matches Python _map_sql_type default) --- + py::str str_val = py::str(obj); + Py_ssize_t length = py::len(str_val); + info.paramSQLType = SQL_WVARCHAR; + info.paramCType = SQL_C_WCHAR; + info.columnSize = length; + info.utf16Len = length; + info.decimalDigits = 0; + // Replace param in-place (safe: execute() copies the caller's list) + PyList_SET_ITEM(params.ptr(), i, str_val.release().ptr()); + } + + return infos; +} + +// Helper: build SQL_NUMERIC_STRUCT from Python Decimal +static py::object build_numeric_data(const py::object& decimal_param) { + py::object as_tuple = decimal_param.attr("as_tuple")(); + py::tuple digits = as_tuple.attr("digits").cast(); + int sign_val = as_tuple.attr("sign").cast(); + py::object exponent_obj = as_tuple.attr("exponent"); + + int exponent = 0; + if 
(py::isinstance(exponent_obj)) { + exponent = exponent_obj.cast(); + } + + int num_digits = static_cast(py::len(digits)); + int precision, scale; + if (exponent >= 0) { + precision = num_digits + exponent; + scale = 0; + } else { + scale = -exponent; + precision = std::max(num_digits, scale); + } + precision = std::max(1, std::min(precision, MAX_NUMERIC_PRECISION)); + scale = std::min(scale, precision); + + py::object py_zero = py::int_(0); + py::object int_val = py_zero; + for (auto d : digits) { + int_val = int_val * py::int_(10) + d.cast(); + } + if (exponent > 0) { + py::object multiplier = py::int_(1); + for (int j = 0; j < exponent; ++j) + multiplier = multiplier * py::int_(10); + int_val = int_val * multiplier; + } + + py::object abs_val = int_val.attr("__abs__")(); + py::bytes val_bytes = abs_val.attr("to_bytes")(py::int_(16), py::str("little")); + std::string val_str = val_bytes.cast(); + + NumericData nd; + nd.precision = static_cast(precision); + nd.scale = static_cast(scale); + nd.sign = (sign_val == 0) ? 1 : 0; + std::memset(&nd.val[0], 0, SQL_MAX_NUMERIC_LEN); + std::memcpy(&nd.val[0], val_str.data(), std::min(val_str.size(), (size_t)SQL_MAX_NUMERIC_LEN)); + + return py::cast(nd); +} + // Given a list of parameters and their ParamInfo, calls SQLBindParameter on // each of them with appropriate arguments SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, @@ -2042,6 +2423,67 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, } } +// --------------------------------------------------------------------------- +// SQLExecuteFast — single C++ pipeline: DetectParamTypes → BindParameters → SQLExecute +// No ParamInfo objects cross the pybind11 boundary. +// +// Always uses SQLPrepare (not ExecDirect) because parameterized queries +// benefit from prepared plan reuse, and the fast path is only invoked +// when parameters are present. 
The use_prepare flag from the caller is +// acknowledged but overridden — this is a perf-only code path. +// --------------------------------------------------------------------------- +SQLRETURN SQLExecuteFast_wrap(const SqlHandlePtr statementHandle, + const std::wstring& query, + py::list params, + py::list is_stmt_prepared, + bool use_prepare, + const py::dict& encoding_settings) { + if (!statementHandle || !statementHandle->get()) { + return SQL_INVALID_HANDLE; + } + + SQLHANDLE hStmt = statementHandle->get(); + std::string charEncoding = "utf-8"; + std::string wcharEncoding = "utf-16le"; + if (encoding_settings.contains("charEncoding")) { + charEncoding = encoding_settings["charEncoding"].cast(); + } + if (encoding_settings.contains("wcharEncoding")) { + wcharEncoding = encoding_settings["wcharEncoding"].cast(); + } + + RETCODE rc; + bool already_prepared = is_stmt_prepared[0].cast(); + + // Prepare if needed (fast path always uses prepare for parameterized queries) + if (!already_prepared) { +#if defined(__APPLE__) || defined(__linux__) + std::vector queryBuffer = WStringToSQLWCHAR(query); + SQLWCHAR* queryPtr = queryBuffer.data(); +#else + SQLWCHAR* queryPtr = const_cast(query.c_str()); +#endif + { + py::gil_scoped_release release; + rc = SQLPrepare_ptr(hStmt, queryPtr, SQL_NTS); + } + if (!SQL_SUCCEEDED(rc)) return rc; + is_stmt_prepared[0] = py::bool_(true); + } + + // DetectParamTypes + BindParameters in one shot — ParamInfo stays in C++ + std::vector paramInfos = DetectParamTypes(params); + std::vector> paramBuffers; + rc = BindParameters(hStmt, params, paramInfos, paramBuffers, charEncoding); + if (!SQL_SUCCEEDED(rc)) return rc; + + { + py::gil_scoped_release release; + rc = SQLExecute_ptr(hStmt); + } + return rc; +} + SQLRETURN BindParameterArray(SQLHANDLE hStmt, const py::list& columnwise_params, const std::vector& paramInfos, size_t paramSetSize, std::vector>& paramBuffers, @@ -5803,6 +6245,10 @@ PYBIND11_MODULE(ddbc_bindings, m) { 
m.def("DDBCSQLExecute", &SQLExecute_wrap, "Prepare and execute T-SQL statements", py::arg("statementHandle"), py::arg("query"), py::arg("params"), py::arg("paramInfos"), py::arg("isStmtPrepared"), py::arg("usePrepare"), py::arg("encodingSettings")); + m.def("DDBCSQLExecuteFast", &SQLExecuteFast_wrap, + "Fast path: DetectParamTypes + BindParameters + SQLExecute all in C++", + py::arg("statementHandle"), py::arg("query"), py::arg("params"), + py::arg("isStmtPrepared"), py::arg("usePrepare"), py::arg("encodingSettings")); m.def("SQLExecuteMany", &SQLExecuteMany_wrap, "Execute statement with multiple parameter sets", py::arg("statementHandle"), py::arg("query"), py::arg("columnwise_params"), py::arg("paramInfos"), py::arg("paramSetSize"), py::arg("encodingSettings")); From 14ccf08f034bc3d5ecc69d5bf80e01954ad22ee5 Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Wed, 29 Apr 2026 18:13:16 +0530 Subject: [PATCH 2/7] Fix DAE handling, MONEY range, and TypeError for unknown types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add complete DAE (Data-At-Execution) loop to SQLExecuteFast_wrap: SQL_NEED_DATA → SQLParamData/SQLPutData for large str/bytes/binary, matching the existing SQLExecute_wrap logic exactly - Fix DAE type assignment: non-unicode DAE strings use SQL_C_CHAR (not PARAM_C_TYPE_TEXT which maps to SQL_C_WCHAR on macOS/Linux) - Fix MONEY range lower bound: use MONEY_MIN not SMALLMONEY_MIN so negative decimals in MONEY range bind as VARCHAR (matches Python path) - Raise TypeError for unknown param types instead of silent str conversion - Add SQLFreeStmt(SQL_RESET_PARAMS) to unbind after execute --- mssql_python/pybind/ddbc_bindings.cpp | 111 +++++++++++++++++++++++--- 1 file changed, 98 insertions(+), 13 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 0fc59701..fd4ab530 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ 
b/mssql_python/pybind/ddbc_bindings.cpp @@ -588,15 +588,20 @@ std::vector DetectParamTypes(py::list& params) { && PyUnicode_MAX_CHAR_VALUE(obj) > 127); if (utf16_len > MAX_INLINE_CHAR) { + // DAE path: match slow-path types exactly. + // Non-unicode → SQL_VARCHAR + SQL_C_CHAR (encoded via Python codec in DAE loop) + // Unicode → SQL_WVARCHAR + SQL_C_WCHAR (wide-char streaming in DAE loop) info.isDAE = true; info.columnSize = 0; info.dataPtr = py::reinterpret_borrow(py::handle(obj)); + info.paramSQLType = is_unicode ? SQL_WVARCHAR : SQL_VARCHAR; + info.paramCType = is_unicode ? SQL_C_WCHAR : SQL_C_CHAR; } else { info.columnSize = is_unicode ? utf16_len : length; info.utf16Len = utf16_len; + info.paramSQLType = is_unicode ? SQL_WVARCHAR : SQL_VARCHAR; + info.paramCType = is_unicode ? SQL_C_WCHAR : PARAM_C_TYPE_TEXT; } - info.paramSQLType = is_unicode ? SQL_WVARCHAR : SQL_VARCHAR; - info.paramCType = is_unicode ? SQL_C_WCHAR : PARAM_C_TYPE_TEXT; info.decimalDigits = 0; // Check geometry prefixes @@ -720,7 +725,7 @@ std::vector DetectParamTypes(py::list& params) { // SMALLMONEY/MONEY range — bind as formatted VARCHAR string // to match SQL Server's fixed-point money semantics. 
double dval = h.attr("__float__")().cast(); - if (dval >= SMALLMONEY_MIN && dval <= MONEY_MAX) { + if (dval >= MONEY_MIN && dval <= MONEY_MAX) { py::str formatted = h.attr("__format__")(py::str("f")); info.paramSQLType = SQL_VARCHAR; info.paramCType = PARAM_C_TYPE_TEXT; @@ -755,16 +760,9 @@ std::vector DetectParamTypes(py::list& params) { continue; } - // --- Fallback: convert to string (matches Python _map_sql_type default) --- - py::str str_val = py::str(obj); - Py_ssize_t length = py::len(str_val); - info.paramSQLType = SQL_WVARCHAR; - info.paramCType = SQL_C_WCHAR; - info.columnSize = length; - info.utf16Len = length; - info.decimalDigits = 0; - // Replace param in-place (safe: execute() copies the caller's list) - PyList_SET_ITEM(params.ptr(), i, str_val.release().ptr()); + // --- Unknown type: raise TypeError (matches Python _map_sql_type) --- + throw py::type_error( + "Unsupported parameter type: The driver cannot safely convert it to a SQL type."); } return infos; @@ -2481,6 +2479,93 @@ SQLRETURN SQLExecuteFast_wrap(const SqlHandlePtr statementHandle, py::gil_scoped_release release; rc = SQLExecute_ptr(hStmt); } + + // DAE (Data-At-Execution) loop: when BindParameters marks a param as DAE + // (large str/bytes/binary), SQLExecute returns SQL_NEED_DATA. We must + // stream the data via SQLParamData/SQLPutData before execution completes. 
+ if (rc == SQL_NEED_DATA) { + SQLPOINTER paramToken = nullptr; + while ((rc = SQLParamData_ptr(hStmt, &paramToken)) == SQL_NEED_DATA) { + const ParamInfo* matchedInfo = nullptr; + for (auto& info : paramInfos) { + if (reinterpret_cast(const_cast(&info)) == paramToken) { + matchedInfo = &info; + break; + } + } + if (!matchedInfo) { + ThrowStdException("SQLExecuteFast: unrecognized paramToken from SQLParamData"); + } + const py::object& pyObj = matchedInfo->dataPtr; + if (pyObj.is_none()) { + SQLPutData_ptr(hStmt, nullptr, 0); + continue; + } + + if (py::isinstance(pyObj)) { + if (matchedInfo->paramCType == SQL_C_WCHAR) { + std::wstring wstr = pyObj.cast(); + const SQLWCHAR* dataPtr = nullptr; + size_t totalChars = 0; +#if defined(__APPLE__) || defined(__linux__) + std::vector sqlwStr = WStringToSQLWCHAR(wstr); + totalChars = sqlwStr.size() - 1; + dataPtr = sqlwStr.data(); +#else + dataPtr = wstr.c_str(); + totalChars = wstr.size(); +#endif + size_t chunkChars = DAE_CHUNK_SIZE / sizeof(SQLWCHAR); + for (size_t offset = 0; offset < totalChars; offset += chunkChars) { + size_t len = std::min(chunkChars, totalChars - offset); + rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), + static_cast(len * sizeof(SQLWCHAR))); + if (!SQL_SUCCEEDED(rc)) return rc; + } + } else if (matchedInfo->paramCType == SQL_C_CHAR) { + std::string encodedStr; + try { + py::object encoded = pyObj.attr("encode")(charEncoding, "strict"); + encodedStr = encoded.cast(); + } catch (const py::error_already_set& e) { + throw; + } + const char* dataPtr = encodedStr.data(); + size_t totalBytes = encodedStr.size(); + for (size_t offset = 0; offset < totalBytes; offset += DAE_CHUNK_SIZE) { + size_t len = std::min(static_cast(DAE_CHUNK_SIZE), + totalBytes - offset); + rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), + static_cast(len)); + if (!SQL_SUCCEEDED(rc)) return rc; + } + } else { + ThrowStdException("SQLExecuteFast: unsupported C type for str in DAE"); + } + } else if
(py::isinstance(pyObj) || + py::isinstance(pyObj)) { + py::bytes b = pyObj.cast(); + std::string s = b; + const char* dataPtr = s.data(); + size_t totalBytes = s.size(); + for (size_t offset = 0; offset < totalBytes; offset += DAE_CHUNK_SIZE) { + size_t len = std::min(static_cast(DAE_CHUNK_SIZE), + totalBytes - offset); + rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), + static_cast(len)); + if (!SQL_SUCCEEDED(rc)) return rc; + } + } else { + ThrowStdException("SQLExecuteFast: DAE only supported for str or bytes"); + } + } + if (!SQL_SUCCEEDED(rc) && rc != SQL_NO_DATA) return rc; + } + + if (!SQL_SUCCEEDED(rc) && rc != SQL_NO_DATA) return rc; + + // Unbind params — buffers go out of scope after this + rc = SQLFreeStmt_ptr(hStmt, SQL_RESET_PARAMS); return rc; } From 8ab074c984b09488fa910241ae7e9ee38afac4af Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Thu, 30 Apr 2026 11:09:25 +0530 Subject: [PATCH 3/7] Fix MSVC warnings-as-errors: unused param and catch variable - Comment out use_prepare parameter name (C4100: unreferenced parameter) - Remove unused catch variable name (C4101: unreferenced local variable) --- mssql_python/pybind/ddbc_bindings.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index fd4ab530..19c3e2f3 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -2434,7 +2434,7 @@ SQLRETURN SQLExecuteFast_wrap(const SqlHandlePtr statementHandle, const std::wstring& query, py::list params, py::list is_stmt_prepared, - bool use_prepare, + bool /*use_prepare*/, const py::dict& encoding_settings) { if (!statementHandle || !statementHandle->get()) { return SQL_INVALID_HANDLE; @@ -2527,7 +2527,7 @@ SQLRETURN SQLExecuteFast_wrap(const SqlHandlePtr statementHandle, try { py::object encoded = pyObj.attr("encode")(charEncoding, "strict"); encodedStr = encoded.cast(); - } catch (const py::error_already_set& e) 
{ + } catch (const py::error_already_set&) { throw; } const char* dataPtr = encodedStr.data(); From 00c85ccbaf161e7f1925501bd1df5b1dcbafde23 Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Thu, 30 Apr 2026 11:23:31 +0530 Subject: [PATCH 4/7] Guard memcpy with null/length check for DevSkim DS121708 Add explicit null pointer and zero-length guards before memcpy in build_numeric_data to satisfy DevSkim code scanning rule DS121708. --- mssql_python/pybind/ddbc_bindings.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 19c3e2f3..9e904f36 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -813,7 +813,10 @@ static py::object build_numeric_data(const py::object& decimal_param) { nd.scale = static_cast(scale); nd.sign = (sign_val == 0) ? 1 : 0; std::memset(&nd.val[0], 0, SQL_MAX_NUMERIC_LEN); - std::memcpy(&nd.val[0], val_str.data(), std::min(val_str.size(), (size_t)SQL_MAX_NUMERIC_LEN)); + size_t copy_len = std::min(val_str.size(), static_cast(SQL_MAX_NUMERIC_LEN)); + if (copy_len > 0 && val_str.data() != nullptr) { + std::memcpy(&nd.val[0], val_str.data(), copy_len); + } return py::cast(nd); } From bad8acf37e1d50fb22bf7245b039fca4e4e2daaa Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Thu, 7 May 2026 12:51:08 +0530 Subject: [PATCH 5/7] STYLE: Fix black formatting in cursor.py --- mssql_python/cursor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 783d0c69..8659b065 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -1492,9 +1492,7 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state if parameters: param_info = ddbc_bindings.ParamInfo for i, param in enumerate(parameters): - paraminfo = self._create_parameter_types_list( - param, param_info, parameters, i - ) + paraminfo = 
self._create_parameter_types_list(param, param_info, parameters, i) parameters_type.append(paraminfo) if logger.isEnabledFor(logging.DEBUG): From c5a827fbc13595ecb58f276639b21117c784825a Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Thu, 7 May 2026 14:04:18 +0530 Subject: [PATCH 6/7] Address PR review: encoding key, subclass support, GIL, exec_rc, cursor attrs, parity test Six review fixes for SQLExecuteFast_wrap and DetectParamTypes: 1. Encoding key: read 'encoding' from settings dict (was 'charEncoding' which never matched). Only honor when ctype==SQL_C_CHAR so the default utf-16le doesn't corrupt SQL_C_CHAR DAE/inline byte paths. 2. Subclass support: PyLong_Check/PyFloat_Check/PyUnicode_Check/PyBytes_Check instead of *_CheckExact. Fixes user-defined int/str/bytes/float subclasses that were silently rejected with TypeError. Switched PyBytes_GET_SIZE to PyBytes_Size for subclass-safe length. 3. GIL release in DAE loop: SQLParamData and SQLPutData now release the GIL during each ODBC call, matching slow-path concurrency for large blobs/strings. 4. Preserve exec_rc: stash the SQLExecute return code before SQLFreeStmt so SUCCESS_WITH_INFO and other non-success-non-error codes are not clobbered by the unbind call. 5. Shallow-copy params: params = py::list(params) at function entry so DetectParamTypes' in-place PyList_SET_ITEM cannot mutate the caller's list under any future code path that might pass it directly. 6. Cursor attrs: SQLSetStmtAttr(SQL_ATTR_CURSOR_TYPE/CONCURRENCY) at entry to match slow-path semantics regardless of prior hstmt state. Also adds tests/test_023_fast_path_parity.py covering int/str/bytes/float subclasses, caller-list non-mutation, and unsupported-type TypeError. 
--- mssql_python/pybind/ddbc_bindings.cpp | 97 +++++++++++------ tests/test_023_fast_path_parity.py | 151 ++++++++++++++++++++++++++ 2 files changed, 216 insertions(+), 32 deletions(-) create mode 100644 tests/test_023_fast_path_parity.py diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index b9909824..8d63303e 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -526,8 +526,8 @@ std::vector DetectParamTypes(py::list& params) { continue; } - // --- int --- - if (PyLong_CheckExact(obj)) { + // --- int (allow subclasses, but bool was already caught above) --- + if (PyLong_Check(obj)) { int overflow = 0; int64_t val = PyLong_AsLongLongAndOverflow(obj, &overflow); if (overflow == 0 && !PyErr_Occurred()) { @@ -558,8 +558,8 @@ std::vector DetectParamTypes(py::list& params) { continue; } - // --- float --- - if (PyFloat_CheckExact(obj)) { + // --- float (allow subclasses) --- + if (PyFloat_Check(obj)) { info.paramSQLType = SQL_DOUBLE; info.paramCType = SQL_C_DOUBLE; info.columnSize = 15; @@ -567,8 +567,8 @@ std::vector DetectParamTypes(py::list& params) { continue; } - // --- str --- - if (PyUnicode_CheckExact(obj)) { + // --- str (allow subclasses) --- + if (PyUnicode_Check(obj)) { Py_ssize_t length = PyUnicode_GET_LENGTH(obj); unsigned int kind = PyUnicode_KIND(obj); @@ -618,10 +618,10 @@ std::vector DetectParamTypes(py::list& params) { continue; } - // --- bytes / bytearray --- - if (PyBytes_CheckExact(obj) || PyByteArray_CheckExact(obj)) { - Py_ssize_t length = PyBytes_CheckExact(obj) ? PyBytes_GET_SIZE(obj) - : PyByteArray_GET_SIZE(obj); + // --- bytes / bytearray (allow subclasses) --- + if (PyBytes_Check(obj) || PyByteArray_Check(obj)) { + Py_ssize_t length = + PyBytes_Check(obj) ? 
PyBytes_Size(obj) : PyByteArray_Size(obj); info.paramSQLType = SQL_VARBINARY; info.paramCType = SQL_C_BINARY; info.decimalDigits = 0; @@ -2496,15 +2496,33 @@ SQLRETURN SQLExecuteFast_wrap(const SqlHandlePtr statementHandle, } SQLHANDLE hStmt = statementHandle->get(); - std::string charEncoding = "utf-8"; - std::string wcharEncoding = "utf-16le"; - if (encoding_settings.contains("charEncoding")) { - charEncoding = encoding_settings["charEncoding"].cast(); + + // Configure forward-only / read-only cursor (matches slow path semantics). + if (SQLSetStmtAttr_ptr) { + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_CURSOR_TYPE, + (SQLPOINTER)SQL_CURSOR_FORWARD_ONLY, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_CONCURRENCY, + (SQLPOINTER)SQL_CONCUR_READ_ONLY, 0); } - if (encoding_settings.contains("wcharEncoding")) { - wcharEncoding = encoding_settings["wcharEncoding"].cast(); + + // Match the slow path's encoding-dict contract: keys are "encoding" and "ctype". + // Only honor the user's encoding when their preferred ctype is SQL_C_CHAR; + // otherwise the default ctype is SQL_C_WCHAR and the "encoding" value is + // meant for wide-char paths (e.g. "utf-16le") and would corrupt the + // SQL_C_CHAR DAE/inline path that operates on byte data. + std::string charEncoding = "utf-8"; + if (encoding_settings.contains("ctype") && encoding_settings.contains("encoding")) { + int ctype = encoding_settings["ctype"].cast(); + if (ctype == SQL_C_CHAR) { + charEncoding = encoding_settings["encoding"].cast(); + } } + // Shallow-copy the parameter list so DetectParamTypes' in-place + // PyList_SET_ITEM never mutates the caller's list. The cost is one + // PyList_New + N refcount bumps; cheap relative to ODBC binding. 
+ params = py::list(params); + RETCODE rc; bool already_prepared = is_stmt_prepared[0].cast(); @@ -2538,9 +2556,16 @@ SQLRETURN SQLExecuteFast_wrap(const SqlHandlePtr statementHandle, // DAE (Data-At-Execution) loop: when BindParameters marks a param as DAE // (large str/bytes/binary), SQLExecute returns SQL_NEED_DATA. We must // stream the data via SQLParamData/SQLPutData before execution completes. + // GIL is released around each ODBC call to match slow-path concurrency. if (rc == SQL_NEED_DATA) { SQLPOINTER paramToken = nullptr; - while ((rc = SQLParamData_ptr(hStmt, &paramToken)) == SQL_NEED_DATA) { + while (true) { + { + py::gil_scoped_release release; + rc = SQLParamData_ptr(hStmt, &paramToken); + } + if (rc != SQL_NEED_DATA) break; + const ParamInfo* matchedInfo = nullptr; for (auto& info : paramInfos) { if (reinterpret_cast(const_cast(&info)) == paramToken) { @@ -2553,6 +2578,7 @@ SQLRETURN SQLExecuteFast_wrap(const SqlHandlePtr statementHandle, } const py::object& pyObj = matchedInfo->dataPtr; if (pyObj.is_none()) { + py::gil_scoped_release release; SQLPutData_ptr(hStmt, nullptr, 0); continue; } @@ -2573,25 +2599,27 @@ SQLRETURN SQLExecuteFast_wrap(const SqlHandlePtr statementHandle, size_t chunkChars = DAE_CHUNK_SIZE / sizeof(SQLWCHAR); for (size_t offset = 0; offset < totalChars; offset += chunkChars) { size_t len = std::min(chunkChars, totalChars - offset); - rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), - static_cast(len * sizeof(SQLWCHAR))); + { + py::gil_scoped_release release; + rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), + static_cast(len * sizeof(SQLWCHAR))); + } if (!SQL_SUCCEEDED(rc)) return rc; } } else if (matchedInfo->paramCType == SQL_C_CHAR) { std::string encodedStr; - try { - py::object encoded = pyObj.attr("encode")(charEncoding, "strict"); - encodedStr = encoded.cast(); - } catch (const py::error_already_set&) { - throw; - } + py::object encoded = pyObj.attr("encode")(charEncoding, "strict"); + encodedStr = 
encoded.cast(); const char* dataPtr = encodedStr.data(); size_t totalBytes = encodedStr.size(); for (size_t offset = 0; offset < totalBytes; offset += DAE_CHUNK_SIZE) { size_t len = std::min(static_cast(DAE_CHUNK_SIZE), totalBytes - offset); - rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), - static_cast(len)); + { + py::gil_scoped_release release; + rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), + static_cast(len)); + } if (!SQL_SUCCEEDED(rc)) return rc; } } else { @@ -2606,8 +2634,11 @@ SQLRETURN SQLExecuteFast_wrap(const SqlHandlePtr statementHandle, for (size_t offset = 0; offset < totalBytes; offset += DAE_CHUNK_SIZE) { size_t len = std::min(static_cast(DAE_CHUNK_SIZE), totalBytes - offset); - rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), - static_cast(len)); + { + py::gil_scoped_release release; + rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), + static_cast(len)); + } if (!SQL_SUCCEEDED(rc)) return rc; } } else { @@ -2619,9 +2650,11 @@ SQLRETURN SQLExecuteFast_wrap(const SqlHandlePtr statementHandle, if (!SQL_SUCCEEDED(rc) && rc != SQL_NO_DATA) return rc; - // Unbind params — buffers go out of scope after this - rc = SQLFreeStmt_ptr(hStmt, SQL_RESET_PARAMS); - return rc; + // Preserve the execute return code (e.g. SQL_SUCCESS_WITH_INFO) — don't + // let the SQLFreeStmt return value clobber what the caller needs to see. + SQLRETURN exec_rc = rc; + SQLFreeStmt_ptr(hStmt, SQL_RESET_PARAMS); + return exec_rc; } SQLRETURN BindParameterArray(SQLHANDLE hStmt, const py::list& columnwise_params, diff --git a/tests/test_023_fast_path_parity.py b/tests/test_023_fast_path_parity.py new file mode 100644 index 00000000..09da6e00 --- /dev/null +++ b/tests/test_023_fast_path_parity.py @@ -0,0 +1,151 @@ +""" +Parity tests: assert that fast path (DetectParamTypes in C++) and slow path +(_map_sql_type in Python) produce identical query results for representative +parameter types. + +The fast path runs by default. 
The slow path is forced by calling setinputsizes() +with a non-None entry, which triggers cursor.execute()'s slow-path branch. +""" + +import pytest +import datetime +import decimal +import uuid +from mssql_python import connect + +import os + +CONN_STR = os.environ.get( + "DB_CONNECTION_STRING", + "Server=localhost;Database=master;Uid=sa;Pwd=Str0ng@Passw0rd123;TrustServerCertificate=yes", +) + + +@pytest.fixture +def conn(): + c = connect(CONN_STR) + yield c + c.close() + + +def _roundtrip(cursor, value): + """Round-trip a single parameter through SELECT ? and return the result.""" + cursor.execute("SELECT ?", [value]) + return cursor.fetchone()[0] + + +def _force_slow_path_roundtrip(cursor, value): + """Force slow path via setinputsizes(None for that param) — any non-empty + inputsizes list with a non-None entry triggers the legacy code path.""" + # Empty list with at least one entry that's not None forces slow path. + # Using SQL_VARCHAR(8000) as an opaque "no override" placeholder. + from mssql_python import ddbc_bindings + + # A None entry means "infer", which is fine — the slow path still runs because + # _inputsizes is set (any non-empty list with at least one non-None entry). + # We need at least one non-None entry to flip use_fast_path to False. + cursor.setinputsizes([None]) # Has at least one entry but no override + # Wait — None doesn't trigger slow path. We need a real override. + # Use SQL_VARCHAR which is identity-ish for strings. 
+ cursor.setinputsizes([(1, 0, 0)]) # (sqlType, size, decimal) tuple + cursor.execute("SELECT ?", [value]) + cursor.setinputsizes(None) # Reset + return cursor.fetchone()[0] + + +@pytest.mark.parametrize( + "value", + [ + # int range detection + 0, + 1, + 255, + 256, + 32767, + 32768, + 2147483647, + 2147483648, + -1, + -32768, + -2147483648, + # bool + True, + False, + # float + 0.0, + 3.14, + -1.5e10, + # str (ASCII inline + DAE + unicode) + "", + "hello", + "a" * 100, + # bytes + b"", + b"\x00\x01\x02", + b"x" * 100, + ], +) +def test_fast_path_roundtrip(conn, value): + """Fast path produces identical results regardless of value type.""" + cur = conn.cursor() + result = _roundtrip(cur, value) + assert ( + result == value + ), f"Roundtrip mismatch for {type(value).__name__} {value!r}: got {result!r}" + + +def test_int_subclass(conn): + """int subclasses must work (regression test for *_CheckExact bug).""" + + class MyInt(int): + pass + + cur = conn.cursor() + assert _roundtrip(cur, MyInt(42)) == 42 + + +def test_str_subclass(conn): + """str subclasses must work.""" + + class MyStr(str): + pass + + cur = conn.cursor() + assert _roundtrip(cur, MyStr("hello")) == "hello" + + +def test_bytes_subclass(conn): + """bytes subclasses must work.""" + + class MyBytes(bytes): + pass + + cur = conn.cursor() + assert _roundtrip(cur, MyBytes(b"hello")) == b"hello" + + +def test_float_subclass(conn): + """float subclasses must work.""" + + class MyFloat(float): + pass + + cur = conn.cursor() + assert _roundtrip(cur, MyFloat(3.14)) == 3.14 + + +def test_caller_param_list_not_mutated(conn): + """DetectParamTypes must not mutate the caller's parameter list.""" + cur = conn.cursor() + params = ["hello", 42, 3.14, datetime.date(2024, 1, 1), uuid.uuid4()] + snapshot = list(params) + cur.execute("SELECT ?, ?, ?, ?, ?", params) + cur.fetchone() + assert params == snapshot, f"Caller list was mutated: {params} != {snapshot}" + + +def test_unsupported_type_raises_typeerror(conn): + 
"""Unknown parameter types must raise TypeError, matching slow path.""" + cur = conn.cursor() + with pytest.raises(TypeError): + cur.execute("SELECT ?", [{1, 2, 3}]) # set is not supported From a38ce4e94aaffe560b8c8efb58d65ad273d9524f Mon Sep 17 00:00:00 2001 From: Gaurav Sharma Date: Thu, 7 May 2026 14:46:24 +0530 Subject: [PATCH 7/7] Address second PR review: refcount leak, geometry+DAE, NaN, parity test Eight follow-up fixes after review feedback on c5a827f. 1. Refcount leak (BLOCKER): replace PyList_SET_ITEM (uppercase, no decref of old slot) with PyList_SetItem (decrefs old slot before stealing the new reference) in DetectParamTypes time/Decimal/UUID branches. The previous shallow-copy defense via py::list(params) was a no-op because pybind11s list constructor only inc_refs an already-list argument. 2. Geometry + DAE conflict: gate the geometry-prefix override on the not-DAE branch so a long POLYGON/POINT/LINESTRING string does not end up with isDAE=true, dataPtr set, AND a non-zero columnSize. 3. Decimal NaN/Infinity: throw ValueError instead of silently binding 0 via build_numeric_data on an empty digits tuple. 4. Time format: always emit microseconds (HH:MM:SS.ffffff), matching slow path isoformat(timespec=microseconds). 5. PyObject_IsInstance: explicit equality check so a custom __instancecheck__ that raises (returns -1) does not fall through with a Python error set. 6. Dead code: removed unused SMALLMONEY_MIN/SMALLMONEY_MAX constants and the unused utf16Len assignments in DetectParamTypes. 7. Encoding-key contract: only honor encoding_settings encoding when the user explicitly opted in via setencoding(..., ctype=SQL_C_CHAR=1). The Python layer SQL_C_CHAR constant is numerically -8 (real ODBC SQL_C_WCHAR), so by default the wide-char path is taken and encoding is irrelevant. 8. 
Parity test rewrite: drop the dead _force_slow_path_roundtrip helper, use the project cursor fixture instead of a hard-coded conn string, and add (a) a real fast-vs-slow parity check via setinputsizes-forced slow path, (b) a refcount-leak regression test using a Decimal subclass + weakref, (c) explicit NaN-rejection coverage. --- mssql_python/pybind/ddbc_bindings.cpp | 94 ++++++------- tests/test_023_fast_path_parity.py | 191 ++++++++++++++++---------- 2 files changed, 165 insertions(+), 120 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 8d63303e..ae0171ac 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -457,11 +457,11 @@ static constexpr int MAX_NUMERIC_PRECISION = 38; static constexpr double MONEY_MIN = -922337203685477.5808; static constexpr double MONEY_MAX = 922337203685477.5807; -// SMALLMONEY range: -214,748.3648 to 214,748.3647 -static constexpr double SMALLMONEY_MIN = -214748.3648; -static constexpr double SMALLMONEY_MAX = 214748.3647; - -// Platform-specific text C type: unixODBC requires all text as wide chars +// Platform-specific text C type: unixODBC requires all text as wide chars on +// Linux/macOS. On Windows the ODBC driver accepts SQL_C_CHAR for ASCII text. +// This matches the Python slow path's behavior (its SQL_C_CHAR constant is +// numerically -8, which is ODBC's SQL_C_WCHAR — a long-standing alias used +// throughout the Python layer). 
#if defined(__APPLE__) || defined(__linux__) static constexpr SQLSMALLINT PARAM_C_TYPE_TEXT = SQL_C_WCHAR; #else @@ -481,7 +481,8 @@ static py::object build_numeric_data(const py::object& decimal_param); // - bool before int (bool is a subclass of int in Python) // - datetime before date (datetime is a subclass of date) // -// Some types mutate the params list in-place via PyList_SET_ITEM: +// Some types mutate the params list in-place via PyList_SetItem (which +// decrefs the old slot before stealing the new ref): // - time → normalized to "HH:MM:SS.ffffff" string // - Decimal in MONEY range → formatted via __format__("f") // - Decimal (generic) → converted to NumericData struct @@ -589,23 +590,26 @@ std::vector DetectParamTypes(py::list& params) { if (utf16_len > MAX_INLINE_CHAR) { // DAE path: match slow-path types exactly. - // Non-unicode → SQL_VARCHAR + SQL_C_CHAR (encoded via Python codec in DAE loop) - // Unicode → SQL_WVARCHAR + SQL_C_WCHAR (wide-char streaming in DAE loop) + // Non-unicode (ASCII) → SQL_VARCHAR + PARAM_C_TYPE_TEXT + // On Linux/macOS PARAM_C_TYPE_TEXT == SQL_C_WCHAR, matching + // the slow path's SQL_C_CHAR (which is numerically -8 == + // SQL_C_WCHAR — a long-standing alias in the Python layer). + // Unicode → SQL_WVARCHAR + SQL_C_WCHAR (wide-char streaming) info.isDAE = true; info.columnSize = 0; info.dataPtr = py::reinterpret_borrow(py::handle(obj)); info.paramSQLType = is_unicode ? SQL_WVARCHAR : SQL_VARCHAR; - info.paramCType = is_unicode ? SQL_C_WCHAR : SQL_C_CHAR; + info.paramCType = is_unicode ? SQL_C_WCHAR : PARAM_C_TYPE_TEXT; } else { info.columnSize = is_unicode ? utf16_len : length; - info.utf16Len = utf16_len; info.paramSQLType = is_unicode ? SQL_WVARCHAR : SQL_VARCHAR; info.paramCType = is_unicode ? 
SQL_C_WCHAR : PARAM_C_TYPE_TEXT; } info.decimalDigits = 0; - // Check geometry prefixes - if (length >= 5 && kind == PyUnicode_1BYTE_KIND) { + // Check geometry prefixes (only for non-DAE strings; long geometry + // values stay on the DAE path with their already-set types). + if (!info.isDAE && length >= 5 && kind == PyUnicode_1BYTE_KIND) { const char* ascii = (const char*)PyUnicode_1BYTE_DATA(obj); if (strncmp(ascii, "POINT", 5) == 0 || (length >= 10 && strncmp(ascii, "LINESTRING", 10) == 0) || @@ -636,7 +640,7 @@ std::vector DetectParamTypes(py::list& params) { } // --- datetime (must check before date, since datetime is subclass of date) --- - if (PyObject_IsInstance(obj, datetime_type)) { + if (PyObject_IsInstance(obj, datetime_type) == 1) { py::handle h(obj); py::object tzinfo = h.attr("tzinfo"); if (!tzinfo.is_none()) { @@ -654,7 +658,7 @@ std::vector DetectParamTypes(py::list& params) { } // --- date --- - if (PyObject_IsInstance(obj, date_type)) { + if (PyObject_IsInstance(obj, date_type) == 1) { info.paramSQLType = SQL_TYPE_DATE; info.paramCType = SQL_C_TYPE_DATE; info.columnSize = 10; @@ -663,9 +667,9 @@ std::vector DetectParamTypes(py::list& params) { } // --- time (normalized to string for binding) --- - if (PyObject_IsInstance(obj, time_type)) { + if (PyObject_IsInstance(obj, time_type) == 1) { info.paramSQLType = SQL_TYPE_TIME; - info.paramCType = PARAM_C_TYPE_TEXT; + info.paramCType = PARAM_C_TYPE_TEXT; // matches slow path (its SQL_C_CHAR is -8 = SQL_C_WCHAR) info.columnSize = 16; info.decimalDigits = 6; py::handle h(obj); @@ -673,34 +677,28 @@ std::vector DetectParamTypes(py::list& params) { int minute = h.attr("minute").cast(); int second = h.attr("second").cast(); int microsecond = h.attr("microsecond").cast(); + // Always include microseconds (matches Python's isoformat(timespec="microseconds")). 
char buf[32]; - if (microsecond > 0) { - snprintf(buf, sizeof(buf), "%02d:%02d:%02d.%06d", hour, minute, second, microsecond); - } else { - snprintf(buf, sizeof(buf), "%02d:%02d:%02d", hour, minute, second); - } + snprintf(buf, sizeof(buf), "%02d:%02d:%02d.%06d", hour, minute, second, microsecond); py::str time_str(buf); Py_ssize_t time_len = py::len(time_str); info.columnSize = std::max(info.columnSize, time_len); - info.utf16Len = time_len; - PyList_SET_ITEM(params.ptr(), i, time_str.release().ptr()); + // PyList_SetItem (lowercase) decrefs the old slot before stealing the new + // reference, so this is safe even if `params` is shared with the caller. + PyList_SetItem(params.ptr(), i, time_str.release().ptr()); continue; } // --- Decimal --- - if (PyObject_IsInstance(obj, decimal_type)) { + if (PyObject_IsInstance(obj, decimal_type) == 1) { py::handle h(obj); py::object as_tuple = h.attr("as_tuple")(); py::object exponent_obj = as_tuple.attr("exponent"); + // NaN / Infinity / sNaN: refuse rather than silently writing 0. 
if (py::isinstance(exponent_obj)) { - info.paramSQLType = SQL_NUMERIC; - info.paramCType = SQL_C_NUMERIC; - info.columnSize = MAX_NUMERIC_PRECISION; - info.decimalDigits = 0; - py::object numeric_data = build_numeric_data(py::reinterpret_borrow(h)); - PyList_SET_ITEM(params.ptr(), i, numeric_data.release().ptr()); - continue; + throw py::value_error( + "Cannot bind non-finite Decimal (NaN/Infinity) as SQL NUMERIC"); } py::tuple digits = as_tuple.attr("digits").cast(); @@ -728,12 +726,11 @@ std::vector DetectParamTypes(py::list& params) { if (dval >= MONEY_MIN && dval <= MONEY_MAX) { py::str formatted = h.attr("__format__")(py::str("f")); info.paramSQLType = SQL_VARCHAR; - info.paramCType = PARAM_C_TYPE_TEXT; + info.paramCType = PARAM_C_TYPE_TEXT; // matches slow path Py_ssize_t fmtLen = py::len(formatted); info.columnSize = fmtLen; - info.utf16Len = fmtLen; info.decimalDigits = 0; - PyList_SET_ITEM(params.ptr(), i, formatted.release().ptr()); + PyList_SetItem(params.ptr(), i, formatted.release().ptr()); continue; } @@ -744,19 +741,19 @@ std::vector DetectParamTypes(py::list& params) { NumericData nd = numeric_data.cast(); info.columnSize = nd.precision; info.decimalDigits = nd.scale; - PyList_SET_ITEM(params.ptr(), i, numeric_data.release().ptr()); + PyList_SetItem(params.ptr(), i, numeric_data.release().ptr()); continue; } // --- UUID --- - if (PyObject_IsInstance(obj, uuid_type)) { + if (PyObject_IsInstance(obj, uuid_type) == 1) { py::handle h(obj); py::bytes bytes_le = h.attr("bytes_le"); info.paramSQLType = SQL_GUID; info.paramCType = SQL_C_GUID; info.columnSize = 16; info.decimalDigits = 0; - PyList_SET_ITEM(params.ptr(), i, bytes_le.release().ptr()); + PyList_SetItem(params.ptr(), i, bytes_le.release().ptr()); continue; } @@ -2505,23 +2502,26 @@ SQLRETURN SQLExecuteFast_wrap(const SqlHandlePtr statementHandle, (SQLPOINTER)SQL_CONCUR_READ_ONLY, 0); } - // Match the slow path's encoding-dict contract: keys are "encoding" and "ctype". 
- // Only honor the user's encoding when their preferred ctype is SQL_C_CHAR; - // otherwise the default ctype is SQL_C_WCHAR and the "encoding" value is - // meant for wide-char paths (e.g. "utf-16le") and would corrupt the - // SQL_C_CHAR DAE/inline path that operates on byte data. + // The encoding-settings dict has the form {"encoding": str, "ctype": int}. + // Note: the Python layer's SQL_C_CHAR constant is numerically -8, the same + // as ODBC's SQL_C_WCHAR. As a result, the only path that genuinely uses + // byte-level character encoding is when the user explicitly opts in via + // setencoding(..., ctype=mssql_python.SQL_CHAR) (which sends ctype=1, the + // real ODBC SQL_CHAR). We default to utf-8 and only honor the dict's + // encoding when ctype == 1 (real ODBC SQL_CHAR). Otherwise the user's + // "encoding" value is meant for the wide-char path and we leave it alone. std::string charEncoding = "utf-8"; if (encoding_settings.contains("ctype") && encoding_settings.contains("encoding")) { int ctype = encoding_settings["ctype"].cast(); - if (ctype == SQL_C_CHAR) { + if (ctype == SQL_C_CHAR /* real ODBC value: 1 */) { charEncoding = encoding_settings["encoding"].cast(); } } - // Shallow-copy the parameter list so DetectParamTypes' in-place - // PyList_SET_ITEM never mutates the caller's list. The cost is one - // PyList_New + N refcount bumps; cheap relative to ODBC binding. - params = py::list(params); + // The cursor.py caller always passes a fresh `list(actual_params)` so this + // function is free to mutate slots in place. Even so, every site below uses + // PyList_SetItem (which decrefs the old slot before stealing the new ref), + // so the function is safe regardless of who owns the list. 
RETCODE rc; bool already_prepared = is_stmt_prepared[0].cast(); diff --git a/tests/test_023_fast_path_parity.py b/tests/test_023_fast_path_parity.py index 09da6e00..331a4700 100644 --- a/tests/test_023_fast_path_parity.py +++ b/tests/test_023_fast_path_parity.py @@ -1,62 +1,54 @@ """ -Parity tests: assert that fast path (DetectParamTypes in C++) and slow path -(_map_sql_type in Python) produce identical query results for representative -parameter types. +Parity tests: assert that fast path (C++ DetectParamTypes + DDBCSQLExecuteFast) +and slow path (Python _map_sql_type + DDBCSQLExecute) produce identical query +results for representative parameter types. -The fast path runs by default. The slow path is forced by calling setinputsizes() -with a non-None entry, which triggers cursor.execute()'s slow-path branch. +Uses the project's `cursor` fixture from conftest.py so the tests work in any +environment that runs the rest of the suite. """ -import pytest import datetime import decimal +import gc import uuid -from mssql_python import connect - -import os +import weakref -CONN_STR = os.environ.get( - "DB_CONNECTION_STRING", - "Server=localhost;Database=master;Uid=sa;Pwd=Str0ng@Passw0rd123;TrustServerCertificate=yes", -) +import pytest +from mssql_python.constants import ConstantsDDBC as ddbc_sql_const -@pytest.fixture -def conn(): - c = connect(CONN_STR) - yield c - c.close() +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- -def _roundtrip(cursor, value): - """Round-trip a single parameter through SELECT ? 
and return the result.""" +def _fast_path_roundtrip(cursor, value): + """Default fast path: no setinputsizes.""" cursor.execute("SELECT ?", [value]) return cursor.fetchone()[0] -def _force_slow_path_roundtrip(cursor, value): - """Force slow path via setinputsizes(None for that param) — any non-empty - inputsizes list with a non-None entry triggers the legacy code path.""" - # Empty list with at least one entry that's not None forces slow path. - # Using SQL_VARCHAR(8000) as an opaque "no override" placeholder. - from mssql_python import ddbc_bindings - - # A None entry means "infer", which is fine — the slow path still runs because - # _inputsizes is set (any non-empty list with at least one non-None entry). - # We need at least one non-None entry to flip use_fast_path to False. - cursor.setinputsizes([None]) # Has at least one entry but no override - # Wait — None doesn't trigger slow path. We need a real override. - # Use SQL_VARCHAR which is identity-ish for strings. - cursor.setinputsizes([(1, 0, 0)]) # (sqlType, size, decimal) tuple - cursor.execute("SELECT ?", [value]) - cursor.setinputsizes(None) # Reset - return cursor.fetchone()[0] +def _slow_path_roundtrip(cursor, value, sql_type, column_size): + """Force the slow path by setting an explicit inputsizes entry. 
The fast + path is gated on `not (self._inputsizes and any(s is not None ...))`, so a + non-None tuple here flips us to the legacy Python type-detection path.""" + cursor.setinputsizes([(sql_type, column_size, 0)]) + try: + cursor.execute("SELECT ?", [value]) + return cursor.fetchone()[0] + finally: + cursor.setinputsizes(None) + + +# --------------------------------------------------------------------------- +# Fast-path coverage: representative type matrix +# --------------------------------------------------------------------------- @pytest.mark.parametrize( "value", [ - # int range detection + # int range detection (TINYINT / SMALLINT / INTEGER / BIGINT) 0, 1, 255, @@ -75,7 +67,7 @@ def _force_slow_path_roundtrip(cursor, value): 0.0, 3.14, -1.5e10, - # str (ASCII inline + DAE + unicode) + # str (ASCII inline) "", "hello", "a" * 100, @@ -85,67 +77,120 @@ def _force_slow_path_roundtrip(cursor, value): b"x" * 100, ], ) -def test_fast_path_roundtrip(conn, value): - """Fast path produces identical results regardless of value type.""" - cur = conn.cursor() - result = _roundtrip(cur, value) - assert ( - result == value - ), f"Roundtrip mismatch for {type(value).__name__} {value!r}: got {result!r}" +def test_fast_path_basic_types(cursor, value): + """Fast path round-trips representative scalar types correctly.""" + result = _fast_path_roundtrip(cursor, value) + assert result == value, ( + f"Fast-path roundtrip mismatch for {type(value).__name__} {value!r}: " f"got {result!r}" + ) + +# --------------------------------------------------------------------------- +# Subclass support — regression for the *_CheckExact bug from PR review +# --------------------------------------------------------------------------- -def test_int_subclass(conn): - """int subclasses must work (regression test for *_CheckExact bug).""" +def test_int_subclass(cursor): class MyInt(int): pass - cur = conn.cursor() - assert _roundtrip(cur, MyInt(42)) == 42 + assert _fast_path_roundtrip(cursor, 
MyInt(42)) == 42 -def test_str_subclass(conn): - """str subclasses must work.""" - +def test_str_subclass(cursor): class MyStr(str): pass - cur = conn.cursor() - assert _roundtrip(cur, MyStr("hello")) == "hello" - + assert _fast_path_roundtrip(cursor, MyStr("hello")) == "hello" -def test_bytes_subclass(conn): - """bytes subclasses must work.""" +def test_bytes_subclass(cursor): class MyBytes(bytes): pass - cur = conn.cursor() - assert _roundtrip(cur, MyBytes(b"hello")) == b"hello" + assert _fast_path_roundtrip(cursor, MyBytes(b"hello")) == b"hello" -def test_float_subclass(conn): - """float subclasses must work.""" - +def test_float_subclass(cursor): class MyFloat(float): pass - cur = conn.cursor() - assert _roundtrip(cur, MyFloat(3.14)) == 3.14 + assert _fast_path_roundtrip(cursor, MyFloat(3.14)) == 3.14 + +# --------------------------------------------------------------------------- +# Caller-list isolation and refcount safety +# --------------------------------------------------------------------------- -def test_caller_param_list_not_mutated(conn): + +def test_caller_param_list_not_mutated(cursor): """DetectParamTypes must not mutate the caller's parameter list.""" - cur = conn.cursor() params = ["hello", 42, 3.14, datetime.date(2024, 1, 1), uuid.uuid4()] snapshot = list(params) - cur.execute("SELECT ?, ?, ?, ?, ?", params) - cur.fetchone() + cursor.execute("SELECT ?, ?, ?, ?, ?", params) + cursor.fetchone() assert params == snapshot, f"Caller list was mutated: {params} != {snapshot}" -def test_unsupported_type_raises_typeerror(conn): - """Unknown parameter types must raise TypeError, matching slow path.""" - cur = conn.cursor() +def test_no_refcount_leak_on_in_place_replacement(cursor): + """Decimal/UUID/time params get replaced in-place inside DetectParamTypes + via PyList_SetItem. 
The replaced object must have its reference dropped — + a regression caught in PR review where PyList_SET_ITEM (uppercase, no + decref) leaked one reference per replaced item per execute.""" + + class TrackedDec(decimal.Decimal): + pass + + td = TrackedDec("123.45") + ref = weakref.ref(td) + params = [td] + del td # drop our local strong reference + + cursor.execute("SELECT ?", params) + cursor.fetchone() + del params # drop the list's strong reference + gc.collect() + + assert ref() is None, ( + "Decimal parameter was leaked: PyList_SetItem must decref the old " + "slot before stealing the new reference." + ) + + +# --------------------------------------------------------------------------- +# Error semantics +# --------------------------------------------------------------------------- + + +def test_unsupported_type_raises_typeerror(cursor): + """Fast path must raise TypeError for unknown parameter types — matching + the slow path's `_map_sql_type` final branch.""" with pytest.raises(TypeError): - cur.execute("SELECT ?", [{1, 2, 3}]) # set is not supported + cursor.execute("SELECT ?", [{1, 2, 3}]) # set is not bindable + + +def test_decimal_nan_rejected(cursor): + """Non-finite Decimals must raise rather than silently bind as 0.""" + with pytest.raises(Exception): # ValueError or DataError, not silent zero + cursor.execute("SELECT ?", [decimal.Decimal("NaN")]) + + +# --------------------------------------------------------------------------- +# Fast-vs-slow parity for representative types +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "value, sql_type, column_size", + [ + ("hello", ddbc_sql_const.SQL_VARCHAR.value, 5), + (42, ddbc_sql_const.SQL_INTEGER.value, 0), + (3.14, ddbc_sql_const.SQL_DOUBLE.value, 0), + (b"data", ddbc_sql_const.SQL_VARBINARY.value, 4), + ], +) +def test_fast_slow_path_parity(cursor, value, sql_type, column_size): + """Same input through both paths produces the same output.""" 
+ fast = _fast_path_roundtrip(cursor, value) + slow = _slow_path_roundtrip(cursor, value, sql_type=sql_type, column_size=column_size) + assert fast == slow, f"Fast/slow path divergence for {value!r}: fast={fast!r} slow={slow!r}"