diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 05324875..8659b065 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -1452,11 +1452,6 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state # Getting encoding setting encoding_settings = self._get_encoding_settings() - # Apply timeout if set (non-zero) - logger.debug("execute: Creating parameter type list") - param_info = ddbc_bindings.ParamInfo - parameters_type = [] - # Validate that inputsizes matches parameter count if both are present if parameters and self._inputsizes: if len(self._inputsizes) != len(parameters): @@ -1468,11 +1463,6 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state Warning, ) - if parameters: - for i, param in enumerate(parameters): - paraminfo = self._create_parameter_types_list(param, param_info, parameters, i) - parameters_type.append(paraminfo) - # Prepare caching: skip SQLPrepare when re-executing the same SQL # with parameters. The HSTMT is reused via _soft_reset_cursor, so the # server-side plan from the previous SQLPrepare is still valid. @@ -1481,30 +1471,54 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state self.is_stmt_prepared = [False] effective_use_prepare = use_prepare and not same_sql - if logger.isEnabledFor(logging.DEBUG): - for i, param in enumerate(parameters): - logger.debug( - """Parameter number: %s, Parameter: %s, - Param Python Type: %s, ParamInfo: %s, %s, %s, %s, %s""", - i + 1, - param, - str(type(param)), - parameters_type[i].paramSQLType, - parameters_type[i].paramCType, - parameters_type[i].columnSize, - parameters_type[i].decimalDigits, - parameters_type[i].inputOutputType, - ) - - ret = ddbc_bindings.DDBCSQLExecute( - self.hstmt, - operation, - parameters, - parameters_type, - self.is_stmt_prepared, - effective_use_prepare, - encoding_settings, + # Fast path: when no inputsizes override, do type detection + bind + execute + # entirely in C++. ParamInfo never crosses the pybind11 boundary. + use_fast_path = parameters and not ( + self._inputsizes and any(s is not None for s in self._inputsizes) ) + + if use_fast_path: + ret = ddbc_bindings.DDBCSQLExecuteFast( + self.hstmt, + operation, + parameters, + self.is_stmt_prepared, + effective_use_prepare, + encoding_settings, + ) + else: + # Slow path: Python-side type detection (used when setinputsizes overrides are present) + parameters_type = [] + if parameters: + param_info = ddbc_bindings.ParamInfo + for i, param in enumerate(parameters): + paraminfo = self._create_parameter_types_list(param, param_info, parameters, i) + parameters_type.append(paraminfo) + + if logger.isEnabledFor(logging.DEBUG): + for i, param in enumerate(parameters): + logger.debug( + """Parameter number: %s, Parameter: %s, + Param Python Type: %s, ParamInfo: %s, %s, %s, %s, %s""", + i + 1, + param, + str(type(param)), + parameters_type[i].paramSQLType, + parameters_type[i].paramCType, + parameters_type[i].columnSize, + parameters_type[i].decimalDigits, + parameters_type[i].inputOutputType, + ) + + ret = ddbc_bindings.DDBCSQLExecute( + self.hstmt, + operation, + parameters, + parameters_type, + self.is_stmt_prepared, + effective_use_prepare, + encoding_settings, + ) # Check return code try: diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index f0a5de75..ae0171ac 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -163,6 +163,7 @@ struct ParamInfo { SQLLEN strLenOrInd = 0; // Required for DAE bool isDAE = false; // Indicates if we need to stream py::object dataPtr; + Py_ssize_t utf16Len = 0; // UTF-16 code unit count for string params }; #ifdef __GNUC__ #pragma GCC diagnostic pop @@ -439,6 +440,384 @@ std::string DescribeChar(unsigned char ch) { } } +// --------------------------------------------------------------------------- +// Constants for DetectParamTypes +// --------------------------------------------------------------------------- + +// Strings longer than this use data-at-execution (DAE) streaming +static constexpr int MAX_INLINE_CHAR = 4000; + +// Binary data longer than this uses DAE streaming (SQL Server max for non-MAX types) +static constexpr int MAX_INLINE_BINARY = 8000; + +// SQL Server maximum numeric precision +static constexpr int MAX_NUMERIC_PRECISION = 38; + +// MONEY range: -922,337,203,685,477.5808 to 922,337,203,685,477.5807 +static constexpr double MONEY_MIN = -922337203685477.5808; +static constexpr double MONEY_MAX = 922337203685477.5807; + +// Platform-specific text C type: unixODBC requires all text as wide chars on +// Linux/macOS. On Windows the ODBC driver accepts SQL_C_CHAR for ASCII text. +// This matches the Python slow path's behavior (its SQL_C_CHAR constant is +// numerically -8, which is ODBC's SQL_C_WCHAR — a long-standing alias used +// throughout the Python layer). +#if defined(__APPLE__) || defined(__linux__) +static constexpr SQLSMALLINT PARAM_C_TYPE_TEXT = SQL_C_WCHAR; +#else +static constexpr SQLSMALLINT PARAM_C_TYPE_TEXT = SQL_C_CHAR; +#endif + +// Forward declare NumericData helper used by decimal path +static py::object build_numeric_data(const py::object& decimal_param); + +// --------------------------------------------------------------------------- +// DetectParamTypes — C++ type detection for the execute() fast path. +// +// Replaces the Python-side _create_parameter_types_list() loop by doing type +// detection entirely in C++ using raw CPython type checks. +// +// ORDERING MATTERS: +// - bool before int (bool is a subclass of int in Python) +// - datetime before date (datetime is a subclass of date) +// +// Some types mutate the params list in-place via PyList_SetItem (which +// decrefs the old slot before stealing the new ref): +// - time → normalized to "HH:MM:SS.ffffff" string +// - Decimal in MONEY range → formatted via __format__("f") +// - Decimal (generic) → converted to NumericData struct +// - UUID → replaced with bytes_le +// This is safe because execute() already copies the caller's param list +// (via list(actual_params)) before reaching this function. +// --------------------------------------------------------------------------- +std::vector DetectParamTypes(py::list& params) { + PythonObjectCache::initialize(); + + const Py_ssize_t n = py::len(params); + std::vector infos(n); + + PyObject* decimal_type = PythonObjectCache::get_decimal_class().ptr(); + PyObject* uuid_type = PythonObjectCache::get_uuid_class().ptr(); + PyObject* datetime_type = PythonObjectCache::get_datetime_class().ptr(); + PyObject* date_type = PythonObjectCache::get_date_class().ptr(); + PyObject* time_type = PythonObjectCache::get_time_class().ptr(); + + for (Py_ssize_t i = 0; i < n; ++i) { + ParamInfo& info = infos[i]; + info.inputOutputType = SQL_PARAM_INPUT; + info.isDAE = false; + + PyObject* obj = PyList_GET_ITEM(params.ptr(), i); + + // --- None --- + if (obj == Py_None) { + info.paramSQLType = SQL_UNKNOWN_TYPE; + info.paramCType = SQL_C_DEFAULT; + info.columnSize = 1; + info.decimalDigits = 0; + continue; + } + + // --- bool (must check before int, since bool is subclass of int) --- + if (PyBool_Check(obj)) { + info.paramSQLType = SQL_BIT; + info.paramCType = SQL_C_BIT; + info.columnSize = 1; + info.decimalDigits = 0; + continue; + } + + // --- int (allow subclasses, but bool was already caught above) --- + if (PyLong_Check(obj)) { + int overflow = 0; + int64_t val = PyLong_AsLongLongAndOverflow(obj, &overflow); + if (overflow == 0 && !PyErr_Occurred()) { + if (val >= 0 && val <= 255) { + info.paramSQLType = SQL_TINYINT; + info.paramCType = SQL_C_TINYINT; + info.columnSize = 3; + } else if (val >= -32768 && val <= 32767) { + info.paramSQLType = SQL_SMALLINT; + info.paramCType = SQL_C_SHORT; + info.columnSize = 5; + } else if (val >= -2147483648LL && val <= 2147483647LL) { + info.paramSQLType = SQL_INTEGER; + info.paramCType = SQL_C_LONG; + info.columnSize = 10; + } else { + info.paramSQLType = SQL_BIGINT; + info.paramCType = SQL_C_SBIGINT; + info.columnSize = 19; + } + } else { + PyErr_Clear(); + info.paramSQLType = SQL_BIGINT; + info.paramCType = SQL_C_SBIGINT; + info.columnSize = 19; + } + info.decimalDigits = 0; + continue; + } + + // --- float (allow subclasses) --- + if (PyFloat_Check(obj)) { + info.paramSQLType = SQL_DOUBLE; + info.paramCType = SQL_C_DOUBLE; + info.columnSize = 15; + info.decimalDigits = 0; + continue; + } + + // --- str (allow subclasses) --- + if (PyUnicode_Check(obj)) { + Py_ssize_t length = PyUnicode_GET_LENGTH(obj); + unsigned int kind = PyUnicode_KIND(obj); + + Py_ssize_t utf16_len; + if (kind <= PyUnicode_2BYTE_KIND) { + utf16_len = length; + } else { + utf16_len = 0; + const Py_UCS4* data = PyUnicode_4BYTE_DATA(obj); + for (Py_ssize_t j = 0; j < length; ++j) { + utf16_len += (data[j] > 0xFFFF) ? 2 : 1; + } + } + + bool is_unicode = (kind > PyUnicode_1BYTE_KIND) || + (PyUnicode_IS_COMPACT_ASCII(obj) == 0 && kind == PyUnicode_1BYTE_KIND + && PyUnicode_MAX_CHAR_VALUE(obj) > 127); + + if (utf16_len > MAX_INLINE_CHAR) { + // DAE path: match slow-path types exactly. + // Non-unicode (ASCII) → SQL_VARCHAR + PARAM_C_TYPE_TEXT + // On Linux/macOS PARAM_C_TYPE_TEXT == SQL_C_WCHAR, matching + // the slow path's SQL_C_CHAR (which is numerically -8 == + // SQL_C_WCHAR — a long-standing alias in the Python layer). + // Unicode → SQL_WVARCHAR + SQL_C_WCHAR (wide-char streaming) + info.isDAE = true; + info.columnSize = 0; + info.dataPtr = py::reinterpret_borrow(py::handle(obj)); + info.paramSQLType = is_unicode ? SQL_WVARCHAR : SQL_VARCHAR; + info.paramCType = is_unicode ? SQL_C_WCHAR : PARAM_C_TYPE_TEXT; + } else { + info.columnSize = is_unicode ? utf16_len : length; + info.paramSQLType = is_unicode ? SQL_WVARCHAR : SQL_VARCHAR; + info.paramCType = is_unicode ? SQL_C_WCHAR : PARAM_C_TYPE_TEXT; + } + info.decimalDigits = 0; + + // Check geometry prefixes (only for non-DAE strings; long geometry + // values stay on the DAE path with their already-set types). + if (!info.isDAE && length >= 5 && kind == PyUnicode_1BYTE_KIND) { + const char* ascii = (const char*)PyUnicode_1BYTE_DATA(obj); + if (strncmp(ascii, "POINT", 5) == 0 || + (length >= 10 && strncmp(ascii, "LINESTRING", 10) == 0) || + (length >= 7 && strncmp(ascii, "POLYGON", 7) == 0)) { + info.paramSQLType = SQL_WVARCHAR; + info.paramCType = SQL_C_WCHAR; + info.columnSize = length; + } + } + continue; + } + + // --- bytes / bytearray (allow subclasses) --- + if (PyBytes_Check(obj) || PyByteArray_Check(obj)) { + Py_ssize_t length = + PyBytes_Check(obj) ? PyBytes_Size(obj) : PyByteArray_Size(obj); + info.paramSQLType = SQL_VARBINARY; + info.paramCType = SQL_C_BINARY; + info.decimalDigits = 0; + if (length > MAX_INLINE_BINARY) { + info.isDAE = true; + info.columnSize = 0; + info.dataPtr = py::reinterpret_borrow(py::handle(obj)); + } else { + info.columnSize = std::max(length, 1); + } + continue; + } + + // --- datetime (must check before date, since datetime is subclass of date) --- + if (PyObject_IsInstance(obj, datetime_type) == 1) { + py::handle h(obj); + py::object tzinfo = h.attr("tzinfo"); + if (!tzinfo.is_none()) { + info.paramSQLType = SQL_SS_TIMESTAMPOFFSET; + info.paramCType = SQL_C_SS_TIMESTAMPOFFSET; + info.columnSize = 34; + info.decimalDigits = 7; + } else { + info.paramSQLType = SQL_TYPE_TIMESTAMP; + info.paramCType = SQL_C_TYPE_TIMESTAMP; + info.columnSize = 26; + info.decimalDigits = 6; + } + continue; + } + + // --- date --- + if (PyObject_IsInstance(obj, date_type) == 1) { + info.paramSQLType = SQL_TYPE_DATE; + info.paramCType = SQL_C_TYPE_DATE; + info.columnSize = 10; + info.decimalDigits = 0; + continue; + } + + // --- time (normalized to string for binding) --- + if (PyObject_IsInstance(obj, time_type) == 1) { + info.paramSQLType = SQL_TYPE_TIME; + info.paramCType = PARAM_C_TYPE_TEXT; // matches slow path (its SQL_C_CHAR is -8 = SQL_C_WCHAR) + info.columnSize = 16; + info.decimalDigits = 6; + py::handle h(obj); + int hour = h.attr("hour").cast(); + int minute = h.attr("minute").cast(); + int second = h.attr("second").cast(); + int microsecond = h.attr("microsecond").cast(); + // Always include microseconds (matches Python's isoformat(timespec="microseconds")). + char buf[32]; + snprintf(buf, sizeof(buf), "%02d:%02d:%02d.%06d", hour, minute, second, microsecond); + py::str time_str(buf); + Py_ssize_t time_len = py::len(time_str); + info.columnSize = std::max(info.columnSize, time_len); + // PyList_SetItem (lowercase) decrefs the old slot before stealing the new + // reference, so this is safe even if `params` is shared with the caller. + PyList_SetItem(params.ptr(), i, time_str.release().ptr()); + continue; + } + + // --- Decimal --- + if (PyObject_IsInstance(obj, decimal_type) == 1) { + py::handle h(obj); + py::object as_tuple = h.attr("as_tuple")(); + py::object exponent_obj = as_tuple.attr("exponent"); + + // NaN / Infinity / sNaN: refuse rather than silently writing 0. + if (py::isinstance(exponent_obj)) { + throw py::value_error( + "Cannot bind non-finite Decimal (NaN/Infinity) as SQL NUMERIC"); + } + + py::tuple digits = as_tuple.attr("digits").cast(); + int num_digits = static_cast(py::len(digits)); + int exponent = exponent_obj.cast(); + int precision; + if (exponent >= 0) + precision = num_digits + exponent; + else if ((-exponent) <= num_digits) + precision = num_digits; + else + precision = -exponent; + + if (precision > MAX_NUMERIC_PRECISION) { + throw py::value_error( + "Precision of the numeric value is too high. " + "The maximum precision supported by SQL Server is " + + std::to_string(MAX_NUMERIC_PRECISION) + ", but got " + + std::to_string(precision) + "."); + } + + // SMALLMONEY/MONEY range — bind as formatted VARCHAR string + // to match SQL Server's fixed-point money semantics. + double dval = h.attr("__float__")().cast(); + if (dval >= MONEY_MIN && dval <= MONEY_MAX) { + py::str formatted = h.attr("__format__")(py::str("f")); + info.paramSQLType = SQL_VARCHAR; + info.paramCType = PARAM_C_TYPE_TEXT; // matches slow path + Py_ssize_t fmtLen = py::len(formatted); + info.columnSize = fmtLen; + info.decimalDigits = 0; + PyList_SetItem(params.ptr(), i, formatted.release().ptr()); + continue; + } + + // Generic numeric binding via SQL_NUMERIC_STRUCT + info.paramSQLType = SQL_NUMERIC; + info.paramCType = SQL_C_NUMERIC; + py::object numeric_data = build_numeric_data(py::reinterpret_borrow(h)); + NumericData nd = numeric_data.cast(); + info.columnSize = nd.precision; + info.decimalDigits = nd.scale; + PyList_SetItem(params.ptr(), i, numeric_data.release().ptr()); + continue; + } + + // --- UUID --- + if (PyObject_IsInstance(obj, uuid_type) == 1) { + py::handle h(obj); + py::bytes bytes_le = h.attr("bytes_le"); + info.paramSQLType = SQL_GUID; + info.paramCType = SQL_C_GUID; + info.columnSize = 16; + info.decimalDigits = 0; + PyList_SetItem(params.ptr(), i, bytes_le.release().ptr()); + continue; + } + + // --- Unknown type: raise TypeError (matches Python _map_sql_type) --- + throw py::type_error( + "Unsupported parameter type: The driver cannot safely convert it to a SQL type."); + } + + return infos; +} + +// Helper: build SQL_NUMERIC_STRUCT from Python Decimal +static py::object build_numeric_data(const py::object& decimal_param) { + py::object as_tuple = decimal_param.attr("as_tuple")(); + py::tuple digits = as_tuple.attr("digits").cast(); + int sign_val = as_tuple.attr("sign").cast(); + py::object exponent_obj = as_tuple.attr("exponent"); + + int exponent = 0; + if (py::isinstance(exponent_obj)) { + exponent = exponent_obj.cast(); + } + + int num_digits = static_cast(py::len(digits)); + int precision, scale; + if (exponent >= 0) { + precision = num_digits + exponent; + scale = 0; + } else { + scale = -exponent; + precision = std::max(num_digits, scale); + } + precision = std::max(1, std::min(precision, MAX_NUMERIC_PRECISION)); + scale = std::min(scale, precision); + + py::object py_zero = py::int_(0); + py::object int_val = py_zero; + for (auto d : digits) { + int_val = int_val * py::int_(10) + d.cast(); + } + if (exponent > 0) { + py::object multiplier = py::int_(1); + for (int j = 0; j < exponent; ++j) + multiplier = multiplier * py::int_(10); + int_val = int_val * multiplier; + } + + py::object abs_val = int_val.attr("__abs__")(); + py::bytes val_bytes = abs_val.attr("to_bytes")(py::int_(16), py::str("little")); + std::string val_str = val_bytes.cast(); + + NumericData nd; + nd.precision = static_cast(precision); + nd.scale = static_cast(scale); + nd.sign = (sign_val == 0) ? 1 : 0; + std::memset(&nd.val[0], 0, SQL_MAX_NUMERIC_LEN); + size_t copy_len = std::min(val_str.size(), static_cast(SQL_MAX_NUMERIC_LEN)); + if (copy_len > 0 && val_str.data() != nullptr) { + std::memcpy(&nd.val[0], val_str.data(), copy_len); + } + + return py::cast(nd); +} + // Given a list of parameters and their ParamInfo, calls SQLBindParameter on // each of them with appropriate arguments SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, @@ -2094,6 +2473,190 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, } } +// --------------------------------------------------------------------------- +// SQLExecuteFast — single C++ pipeline: DetectParamTypes → BindParameters → SQLExecute +// No ParamInfo objects cross the pybind11 boundary. +// +// Always uses SQLPrepare (not ExecDirect) because parameterized queries +// benefit from prepared plan reuse, and the fast path is only invoked +// when parameters are present. The use_prepare flag from the caller is +// acknowledged but overridden — this is a perf-only code path. +// --------------------------------------------------------------------------- +SQLRETURN SQLExecuteFast_wrap(const SqlHandlePtr statementHandle, + const std::wstring& query, + py::list params, + py::list is_stmt_prepared, + bool /*use_prepare*/, + const py::dict& encoding_settings) { + if (!statementHandle || !statementHandle->get()) { + return SQL_INVALID_HANDLE; + } + + SQLHANDLE hStmt = statementHandle->get(); + + // Configure forward-only / read-only cursor (matches slow path semantics). + if (SQLSetStmtAttr_ptr) { + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_CURSOR_TYPE, + (SQLPOINTER)SQL_CURSOR_FORWARD_ONLY, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_CONCURRENCY, + (SQLPOINTER)SQL_CONCUR_READ_ONLY, 0); + } + + // The encoding-settings dict has the form {"encoding": str, "ctype": int}. + // Note: the Python layer's SQL_C_CHAR constant is numerically -8, the same + // as ODBC's SQL_C_WCHAR. As a result, the only path that genuinely uses + // byte-level character encoding is when the user explicitly opts in via + // setencoding(..., ctype=mssql_python.SQL_CHAR) (which sends ctype=1, the + // real ODBC SQL_CHAR). We default to utf-8 and only honor the dict's + // encoding when ctype == 1 (real ODBC SQL_CHAR). Otherwise the user's + // "encoding" value is meant for the wide-char path and we leave it alone. + std::string charEncoding = "utf-8"; + if (encoding_settings.contains("ctype") && encoding_settings.contains("encoding")) { + int ctype = encoding_settings["ctype"].cast(); + if (ctype == SQL_C_CHAR /* real ODBC value: 1 */) { + charEncoding = encoding_settings["encoding"].cast(); + } + } + + // The cursor.py caller always passes a fresh `list(actual_params)` so this + // function is free to mutate slots in place. Even so, every site below uses + // PyList_SetItem (which decrefs the old slot before stealing the new ref), + // so the function is safe regardless of who owns the list. + + RETCODE rc; + bool already_prepared = is_stmt_prepared[0].cast(); + + // Prepare if needed (fast path always uses prepare for parameterized queries) + if (!already_prepared) { +#if defined(__APPLE__) || defined(__linux__) + std::vector queryBuffer = WStringToSQLWCHAR(query); + SQLWCHAR* queryPtr = queryBuffer.data(); +#else + SQLWCHAR* queryPtr = const_cast(query.c_str()); +#endif + { + py::gil_scoped_release release; + rc = SQLPrepare_ptr(hStmt, queryPtr, SQL_NTS); + } + if (!SQL_SUCCEEDED(rc)) return rc; + is_stmt_prepared[0] = py::bool_(true); + } + + // DetectParamTypes + BindParameters in one shot — ParamInfo stays in C++ + std::vector paramInfos = DetectParamTypes(params); + std::vector> paramBuffers; + rc = BindParameters(hStmt, params, paramInfos, paramBuffers, charEncoding); + if (!SQL_SUCCEEDED(rc)) return rc; + + { + py::gil_scoped_release release; + rc = SQLExecute_ptr(hStmt); + } + + // DAE (Data-At-Execution) loop: when BindParameters marks a param as DAE + // (large str/bytes/binary), SQLExecute returns SQL_NEED_DATA. We must + // stream the data via SQLParamData/SQLPutData before execution completes. + // GIL is released around each ODBC call to match slow-path concurrency. + if (rc == SQL_NEED_DATA) { + SQLPOINTER paramToken = nullptr; + while (true) { + { + py::gil_scoped_release release; + rc = SQLParamData_ptr(hStmt, ¶mToken); + } + if (rc != SQL_NEED_DATA) break; + + const ParamInfo* matchedInfo = nullptr; + for (auto& info : paramInfos) { + if (reinterpret_cast(const_cast(&info)) == paramToken) { + matchedInfo = &info; + break; + } + } + if (!matchedInfo) { + ThrowStdException("SQLExecuteFast: unrecognized paramToken from SQLParamData"); + } + const py::object& pyObj = matchedInfo->dataPtr; + if (pyObj.is_none()) { + py::gil_scoped_release release; + SQLPutData_ptr(hStmt, nullptr, 0); + continue; + } + + if (py::isinstance(pyObj)) { + if (matchedInfo->paramCType == SQL_C_WCHAR) { + std::wstring wstr = pyObj.cast(); + const SQLWCHAR* dataPtr = nullptr; + size_t totalChars = 0; +#if defined(__APPLE__) || defined(__linux__) + std::vector sqlwStr = WStringToSQLWCHAR(wstr); + totalChars = sqlwStr.size() - 1; + dataPtr = sqlwStr.data(); +#else + dataPtr = wstr.c_str(); + totalChars = wstr.size(); +#endif + size_t chunkChars = DAE_CHUNK_SIZE / sizeof(SQLWCHAR); + for (size_t offset = 0; offset < totalChars; offset += chunkChars) { + size_t len = std::min(chunkChars, totalChars - offset); + { + py::gil_scoped_release release; + rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), + static_cast(len * sizeof(SQLWCHAR))); + } + if (!SQL_SUCCEEDED(rc)) return rc; + } + } else if (matchedInfo->paramCType == SQL_C_CHAR) { + std::string encodedStr; + py::object encoded = pyObj.attr("encode")(charEncoding, "strict"); + encodedStr = encoded.cast(); + const char* dataPtr = encodedStr.data(); + size_t totalBytes = encodedStr.size(); + for (size_t offset = 0; offset < totalBytes; offset += DAE_CHUNK_SIZE) { + size_t len = std::min(static_cast(DAE_CHUNK_SIZE), + totalBytes - offset); + { + py::gil_scoped_release release; + rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), + static_cast(len)); + } + if (!SQL_SUCCEEDED(rc)) return rc; + } + } else { + ThrowStdException("SQLExecuteFast: unsupported C type for str in DAE"); + } + } else if (py::isinstance(pyObj) || + py::isinstance(pyObj)) { + py::bytes b = pyObj.cast(); + std::string s = b; + const char* dataPtr = s.data(); + size_t totalBytes = s.size(); + for (size_t offset = 0; offset < totalBytes; offset += DAE_CHUNK_SIZE) { + size_t len = std::min(static_cast(DAE_CHUNK_SIZE), + totalBytes - offset); + { + py::gil_scoped_release release; + rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), + static_cast(len)); + } + if (!SQL_SUCCEEDED(rc)) return rc; + } + } else { + ThrowStdException("SQLExecuteFast: DAE only supported for str or bytes"); + } + } + if (!SQL_SUCCEEDED(rc) && rc != SQL_NO_DATA) return rc; + } + + if (!SQL_SUCCEEDED(rc) && rc != SQL_NO_DATA) return rc; + + // Preserve the execute return code (e.g. SQL_SUCCESS_WITH_INFO) — don't + // let the SQLFreeStmt return value clobber what the caller needs to see. + SQLRETURN exec_rc = rc; + SQLFreeStmt_ptr(hStmt, SQL_RESET_PARAMS); + return exec_rc; +} + SQLRETURN BindParameterArray(SQLHANDLE hStmt, const py::list& columnwise_params, const std::vector& paramInfos, size_t paramSetSize, std::vector>& paramBuffers, @@ -5915,6 +6478,10 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLExecute", &SQLExecute_wrap, "Prepare and execute T-SQL statements", py::arg("statementHandle"), py::arg("query"), py::arg("params"), py::arg("paramInfos"), py::arg("isStmtPrepared"), py::arg("usePrepare"), py::arg("encodingSettings")); + m.def("DDBCSQLExecuteFast", &SQLExecuteFast_wrap, + "Fast path: DetectParamTypes + BindParameters + SQLExecute all in C++", + py::arg("statementHandle"), py::arg("query"), py::arg("params"), + py::arg("isStmtPrepared"), py::arg("usePrepare"), py::arg("encodingSettings")); m.def("SQLExecuteMany", &SQLExecuteMany_wrap, "Execute statement with multiple parameter sets", py::arg("statementHandle"), py::arg("query"), py::arg("columnwise_params"), py::arg("paramInfos"), py::arg("paramSetSize"), py::arg("encodingSettings")); diff --git a/tests/test_023_fast_path_parity.py b/tests/test_023_fast_path_parity.py new file mode 100644 index 00000000..331a4700 --- /dev/null +++ b/tests/test_023_fast_path_parity.py @@ -0,0 +1,196 @@ +""" +Parity tests: assert that fast path (C++ DetectParamTypes + DDBCSQLExecuteFast) +and slow path (Python _map_sql_type + DDBCSQLExecute) produce identical query +results for representative parameter types. + +Uses the project's `cursor` fixture from conftest.py so the tests work in any +environment that runs the rest of the suite. +""" + +import datetime +import decimal +import gc +import uuid +import weakref + +import pytest + +from mssql_python.constants import ConstantsDDBC as ddbc_sql_const + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _fast_path_roundtrip(cursor, value): + """Default fast path: no setinputsizes.""" + cursor.execute("SELECT ?", [value]) + return cursor.fetchone()[0] + + +def _slow_path_roundtrip(cursor, value, sql_type, column_size): + """Force the slow path by setting an explicit inputsizes entry. The fast + path is gated on `not (self._inputsizes and any(s is not None ...))`, so a + non-None tuple here flips us to the legacy Python type-detection path.""" + cursor.setinputsizes([(sql_type, column_size, 0)]) + try: + cursor.execute("SELECT ?", [value]) + return cursor.fetchone()[0] + finally: + cursor.setinputsizes(None) + + +# --------------------------------------------------------------------------- +# Fast-path coverage: representative type matrix +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "value", + [ + # int range detection (TINYINT / SMALLINT / INTEGER / BIGINT) + 0, + 1, + 255, + 256, + 32767, + 32768, + 2147483647, + 2147483648, + -1, + -32768, + -2147483648, + # bool + True, + False, + # float + 0.0, + 3.14, + -1.5e10, + # str (ASCII inline) + "", + "hello", + "a" * 100, + # bytes + b"", + b"\x00\x01\x02", + b"x" * 100, + ], +) +def test_fast_path_basic_types(cursor, value): + """Fast path round-trips representative scalar types correctly.""" + result = _fast_path_roundtrip(cursor, value) + assert result == value, ( + f"Fast-path roundtrip mismatch for {type(value).__name__} {value!r}: " f"got {result!r}" + ) + + +# --------------------------------------------------------------------------- +# Subclass support — regression for the *_CheckExact bug from PR review +# --------------------------------------------------------------------------- + + +def test_int_subclass(cursor): + class MyInt(int): + pass + + assert _fast_path_roundtrip(cursor, MyInt(42)) == 42 + + +def test_str_subclass(cursor): + class MyStr(str): + pass + + assert _fast_path_roundtrip(cursor, MyStr("hello")) == "hello" + + +def test_bytes_subclass(cursor): + class MyBytes(bytes): + pass + + assert _fast_path_roundtrip(cursor, MyBytes(b"hello")) == b"hello" + + +def test_float_subclass(cursor): + class MyFloat(float): + pass + + assert _fast_path_roundtrip(cursor, MyFloat(3.14)) == 3.14 + + +# --------------------------------------------------------------------------- +# Caller-list isolation and refcount safety +# --------------------------------------------------------------------------- + + +def test_caller_param_list_not_mutated(cursor): + """DetectParamTypes must not mutate the caller's parameter list.""" + params = ["hello", 42, 3.14, datetime.date(2024, 1, 1), uuid.uuid4()] + snapshot = list(params) + cursor.execute("SELECT ?, ?, ?, ?, ?", params) + cursor.fetchone() + assert params == snapshot, f"Caller list was mutated: {params} != {snapshot}" + + +def test_no_refcount_leak_on_in_place_replacement(cursor): + """Decimal/UUID/time params get replaced in-place inside DetectParamTypes + via PyList_SetItem. The replaced object must have its reference dropped — + a regression caught in PR review where PyList_SET_ITEM (uppercase, no + decref) leaked one reference per replaced item per execute.""" + + class TrackedDec(decimal.Decimal): + pass + + td = TrackedDec("123.45") + ref = weakref.ref(td) + params = [td] + del td # drop our local strong reference + + cursor.execute("SELECT ?", params) + cursor.fetchone() + del params # drop the list's strong reference + gc.collect() + + assert ref() is None, ( + "Decimal parameter was leaked: PyList_SetItem must decref the old " + "slot before stealing the new reference." + ) + + +# --------------------------------------------------------------------------- +# Error semantics +# --------------------------------------------------------------------------- + + +def test_unsupported_type_raises_typeerror(cursor): + """Fast path must raise TypeError for unknown parameter types — matching + the slow path's `_map_sql_type` final branch.""" + with pytest.raises(TypeError): + cursor.execute("SELECT ?", [{1, 2, 3}]) # set is not bindable + + +def test_decimal_nan_rejected(cursor): + """Non-finite Decimals must raise rather than silently bind as 0.""" + with pytest.raises(Exception): # ValueError or DataError, not silent zero + cursor.execute("SELECT ?", [decimal.Decimal("NaN")]) + + +# --------------------------------------------------------------------------- +# Fast-vs-slow parity for representative types +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "value, sql_type, column_size", + [ + ("hello", ddbc_sql_const.SQL_VARCHAR.value, 5), + (42, ddbc_sql_const.SQL_INTEGER.value, 0), + (3.14, ddbc_sql_const.SQL_DOUBLE.value, 0), + (b"data", ddbc_sql_const.SQL_VARBINARY.value, 4), + ], +) +def test_fast_slow_path_parity(cursor, value, sql_type, column_size): + """Same input through both paths produces the same output.""" + fast = _fast_path_roundtrip(cursor, value) + slow = _slow_path_roundtrip(cursor, value, sql_type=sql_type, column_size=column_size) + assert fast == slow, f"Fast/slow path divergence for {value!r}: fast={fast!r} slow={slow!r}"