From cd8a3875fdb72d79bea644f88bafcff222cc80cf Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Fri, 26 Jun 2026 12:01:33 -0700 Subject: [PATCH 1/4] Fix TreeEnsemble v5 leaf target validation Validate v5 leaf_targetids while converting to the internal v3 attribute representation so malformed models are rejected during initialization. Add full-range aggregator target index checks for defense in depth and cover invalid v5 leaf targets with provider tests. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../cpu/ml/tree_ensemble_aggregator.h | 4 +- .../cpu/ml/tree_ensemble_attribute.h | 12 +++ .../providers/cpu/ml/tree_ensembler_test.cc | 73 +++++++++++++++++++ 3 files changed, 88 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h index ebdac882bff09..d322179c26331 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h @@ -290,7 +290,7 @@ class TreeAggregatorSum : public TreeAggregator> weights) const { auto it = weights.begin() + root.truenode_or_weight.weight_data.weight; for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) { - ORT_ENFORCE(it->i < (int64_t)predictions.size()); + ORT_ENFORCE(it->i >= 0 && it->i < static_cast(predictions.size())); predictions[onnxruntime::narrow(it->i)].score += it->value; predictions[onnxruntime::narrow(it->i)].has_score = 1; } @@ -393,6 +393,7 @@ class TreeAggregatorMin : public TreeAggregator> weights) const { auto it = weights.begin() + root.truenode_or_weight.weight_data.weight; for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) { + ORT_ENFORCE(it->i >= 0 && it->i < static_cast(predictions.size())); predictions[onnxruntime::narrow(it->i)].score = (!predictions[onnxruntime::narrow(it->i)].has_score || it->value < predictions[onnxruntime::narrow(it->i)].score) ? it->value @@ -449,6 +450,7 @@ class TreeAggregatorMax : public TreeAggregator> weights) const { auto it = weights.begin() + root.truenode_or_weight.weight_data.weight; for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) { + ORT_ENFORCE(it->i >= 0 && it->i < static_cast(predictions.size())); predictions[onnxruntime::narrow(it->i)].score = (!predictions[onnxruntime::narrow(it->i)].has_score || it->value > predictions[onnxruntime::narrow(it->i)].score) ? it->value diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h index 1c89cc0b4723d..f8dcf3b4584ff 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h @@ -214,6 +214,18 @@ struct TreeEnsembleAttributesV5 { void convert_to_v3(TreeEnsembleAttributesV3& output) const { // Doing all transformations to get the old format. + ORT_ENFORCE(n_targets > 0, + "n_targets must be positive, got ", n_targets); + if (!leaf_targetids.empty()) { + const int64_t min_target_id = *std::min_element(leaf_targetids.begin(), leaf_targetids.end()); + ORT_ENFORCE(min_target_id >= 0, + "leaf_targetids cannot have negative values (", min_target_id, ")."); + const int64_t max_target_id = *std::max_element(leaf_targetids.begin(), leaf_targetids.end()); + ORT_ENFORCE(max_target_id < n_targets, + "At least one value (", max_target_id, + ") in leaf_targetids is greater or equal to the number of targets (", + n_targets, ")."); + } output.n_targets_or_classes = n_targets; output.aggregate_function = aggregateFunctionToString(); output.post_transform = postTransformToString(); diff --git a/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc b/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc index 1510f3fe3e012..6612050061854 100644 --- a/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc +++ b/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc @@ -93,6 +93,47 @@ void _multiply_arrays_values(std::vector& data, int64_t val) { } } +static void RunInvalidLeafTargetIdsTest(int64_t aggregate_function, + int64_t n_targets, + std::vector leaf_targetids, + const std::string& expected_error) { + OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain); + + const int64_t post_transform = 0; + std::vector tree_roots = {0}; + std::vector nodes_modes = {0}; + std::vector nodes_featureids = {0}; + std::vector nodes_splits = {0.0}; + std::vector nodes_truenodeids = {0}; + std::vector nodes_trueleafs = {1}; + std::vector nodes_falsenodeids = {0}; + std::vector nodes_falseleafs = {1}; + std::vector leaf_weights = {1.0}; + + auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes"); + auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits"); + auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight"); + + test.AddAttribute("n_targets", n_targets); + test.AddAttribute("aggregate_function", aggregate_function); + test.AddAttribute("post_transform", post_transform); + test.AddAttribute("tree_roots", tree_roots); + test.AddAttribute("nodes_modes", nodes_modes_as_tensor); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_splits", nodes_splits_as_tensor); + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_trueleafs", nodes_trueleafs); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_falseleafs", nodes_falseleafs); + test.AddAttribute("leaf_targetids", leaf_targetids); + test.AddAttribute("leaf_weights", leaf_weights_as_tensor); + + const int64_t output_targets = n_targets > 0 ? n_targets : 0; + test.AddInput("X", {1, 1}, {1.0}); + test.AddOutput("Y", {1, output_targets}, std::vector(static_cast(output_targets), 0.0)); + test.Run(OpTester::ExpectResult::kExpectFailure, expected_error); +} + template void GenTreeAndRunTest(const std::vector& X, const std::vector& Y, const int64_t& aggregate_function, int n_trees = 1) { OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain); @@ -290,6 +331,38 @@ TEST(MLOpTest, TreeEnsembleLeafOnly) { test.Run(); } +TEST(MLOpTest, TreeEnsembleMinLeafTargetIdsOutsideBoundary) { + RunInvalidLeafTargetIdsTest( + 2, + 2, + {5}, + "At least one value (5) in leaf_targetids is greater or equal to the number of targets (2)"); +} + +TEST(MLOpTest, TreeEnsembleMaxLeafTargetIdsOutsideBoundary) { + RunInvalidLeafTargetIdsTest( + 3, + 2, + {5}, + "At least one value (5) in leaf_targetids is greater or equal to the number of targets (2)"); +} + +TEST(MLOpTest, TreeEnsembleNegativeLeafTargetIds) { + RunInvalidLeafTargetIdsTest( + 1, + 2, + {-1}, + "leaf_targetids cannot have negative values (-1)"); +} + +TEST(MLOpTest, TreeEnsembleZeroTargets) { + RunInvalidLeafTargetIdsTest( + 1, + 0, + {0}, + "n_targets must be positive, got 0"); +} + TEST(MLOpTest, TreeEnsembleLeafLike) { OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain); int64_t n_targets = 1; From 04e403512dd3bef94b9faaa0bf20622de098aa96 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Fri, 26 Jun 2026 14:37:05 -0700 Subject: [PATCH 2/4] Refactor TreeEnsemble target validation Share target id range validation between v3 attributes and v5 conversion. Also store the checked target index once in the N-output aggregators before indexing predictions. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../cpu/ml/tree_ensemble_aggregator.h | 23 ++++---- .../cpu/ml/tree_ensemble_attribute.h | 57 +++++++++++-------- 2 files changed, 46 insertions(+), 34 deletions(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h index d322179c26331..2878147c815b5 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h @@ -291,8 +291,9 @@ class TreeAggregatorSum : public TreeAggregatori >= 0 && it->i < static_cast(predictions.size())); - predictions[onnxruntime::narrow(it->i)].score += it->value; - predictions[onnxruntime::narrow(it->i)].has_score = 1; + const size_t target_id = onnxruntime::narrow(it->i); + predictions[target_id].score += it->value; + predictions[target_id].has_score = 1; } } @@ -394,11 +395,12 @@ class TreeAggregatorMin : public TreeAggregatori >= 0 && it->i < static_cast(predictions.size())); - predictions[onnxruntime::narrow(it->i)].score = - (!predictions[onnxruntime::narrow(it->i)].has_score || it->value < predictions[onnxruntime::narrow(it->i)].score) + const size_t target_id = onnxruntime::narrow(it->i); + predictions[target_id].score = + (!predictions[target_id].has_score || it->value < predictions[target_id].score) ? it->value - : predictions[onnxruntime::narrow(it->i)].score; - predictions[onnxruntime::narrow(it->i)].has_score = 1; + : predictions[target_id].score; + predictions[target_id].has_score = 1; } } @@ -451,11 +453,12 @@ class TreeAggregatorMax : public TreeAggregatori >= 0 && it->i < static_cast(predictions.size())); - predictions[onnxruntime::narrow(it->i)].score = - (!predictions[onnxruntime::narrow(it->i)].has_score || it->value > predictions[onnxruntime::narrow(it->i)].score) + const size_t target_id = onnxruntime::narrow(it->i); + predictions[target_id].score = + (!predictions[target_id].has_score || it->value > predictions[target_id].score) ? it->value - : predictions[onnxruntime::narrow(it->i)].score; - predictions[onnxruntime::narrow(it->i)].has_score = 1; + : predictions[target_id].score; + predictions[target_id].has_score = 1; } } diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h index f8dcf3b4584ff..e329886b5aca6 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h @@ -3,14 +3,19 @@ #pragma once +#include +#include +#include +#include +#include + +#include "gsl/gsl" + #include "core/common/inlined_containers.h" #include "core/common/common.h" #include "core/framework/op_kernel.h" #include "ml_common.h" #include "tree_ensemble_helper.h" -#include -#include -#include namespace onnxruntime { namespace ml { @@ -21,6 +26,27 @@ inline bool _isnan_(double x) { return std::isnan(x); } inline bool _isnan_(int64_t) { return false; } inline bool _isnan_(int32_t) { return false; } +inline void ValidateTargetIds(gsl::span target_ids, + int64_t target_count, + std::string_view target_ids_name, + std::string_view target_count_name, + std::string_view target_count_display_name) { + ORT_ENFORCE(target_count > 0, + target_count_name, " must be positive, got ", target_count); + + if (target_ids.empty()) { + return; + } + + const auto [min_target_id, max_target_id] = std::minmax_element(target_ids.begin(), target_ids.end()); + ORT_ENFORCE(*min_target_id >= 0, + target_ids_name, " cannot have negative values (", *min_target_id, ")."); + ORT_ENFORCE(*max_target_id < target_count, + "At least one value (", *max_target_id, ") in ", target_ids_name, + " is greater or equal to ", target_count_display_name, " (", + target_count, ")."); +} + template struct TreeEnsembleAttributesV3 { TreeEnsembleAttributesV3() : n_targets_or_classes(0) {} @@ -71,8 +97,8 @@ struct TreeEnsembleAttributesV3 { target_class_weights = info.GetAttrsOrDefault("target_weights"); } - ORT_ENFORCE(n_targets_or_classes > 0, - "n_targets_or_classes must be positive, got ", n_targets_or_classes); + ValidateTargetIds(target_class_ids, n_targets_or_classes, + "target_ids or class_ids", "n_targets_or_classes", "the number of targets or classes"); ORT_ENFORCE(nodes_falsenodeids.size() == nodes_featureids.size(), "nodes_falsenodeids and nodes_featureids must have the same size, got ", nodes_falsenodeids.size(), " and ", nodes_featureids.size()); @@ -138,13 +164,6 @@ struct TreeEnsembleAttributesV3 { "base_values_as_tensor should have 0 or ", n_targets_or_classes, " values."); } } - - int64_t min_ids = *std::min_element(target_class_ids.begin(), target_class_ids.end()); - ORT_ENFORCE(min_ids >= 0, "target_ids or class_ids cannot have negative values (", min_ids, ")."); - int64_t max_ids = *std::max_element(target_class_ids.begin(), target_class_ids.end()); - ORT_ENFORCE(max_ids < n_targets_or_classes, "At least one value (", max_ids, - ") in target_ids or class_ids is greater or equal to the number of targets or classes (", - n_targets_or_classes, ")."); } std::string aggregate_function; @@ -214,18 +233,8 @@ struct TreeEnsembleAttributesV5 { void convert_to_v3(TreeEnsembleAttributesV3& output) const { // Doing all transformations to get the old format. - ORT_ENFORCE(n_targets > 0, - "n_targets must be positive, got ", n_targets); - if (!leaf_targetids.empty()) { - const int64_t min_target_id = *std::min_element(leaf_targetids.begin(), leaf_targetids.end()); - ORT_ENFORCE(min_target_id >= 0, - "leaf_targetids cannot have negative values (", min_target_id, ")."); - const int64_t max_target_id = *std::max_element(leaf_targetids.begin(), leaf_targetids.end()); - ORT_ENFORCE(max_target_id < n_targets, - "At least one value (", max_target_id, - ") in leaf_targetids is greater or equal to the number of targets (", - n_targets, ")."); - } + ValidateTargetIds(leaf_targetids, n_targets, + "leaf_targetids", "n_targets", "the number of targets"); output.n_targets_or_classes = n_targets; output.aggregate_function = aggregateFunctionToString(); output.post_transform = postTransformToString(); From 8ab1627a5f074df2917237ea0ac594aa44e4955f Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Fri, 26 Jun 2026 14:50:44 -0700 Subject: [PATCH 3/4] Document TreeEnsemble target validation helper Add a brief comment summarizing the target count and target id range invariants checked by ValidateTargetIds. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h index e329886b5aca6..dfd21b556438c 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h @@ -26,6 +26,7 @@ inline bool _isnan_(double x) { return std::isnan(x); } inline bool _isnan_(int64_t) { return false; } inline bool _isnan_(int32_t) { return false; } +// Target count must be positive, and target ids must be in [0, target_count). inline void ValidateTargetIds(gsl::span target_ids, int64_t target_count, std::string_view target_ids_name, From dde8dc8d5aea91c0bdbbcdc5971fe35f39ef7e1c Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Fri, 26 Jun 2026 15:35:20 -0700 Subject: [PATCH 4/4] Centralize TreeEnsemble target id validation Move target/class id range validation into the shared TreeEnsembleCommon::Init path so all normalized entry points are covered. Use generic target/class id error messages and keep the v3 constructor's early positive-count check for base-value validation. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../cpu/ml/tree_ensemble_attribute.h | 32 ++----------------- .../providers/cpu/ml/tree_ensemble_common.h | 14 +++++++- .../providers/cpu/ml/tree_ensembler_test.cc | 8 ++--- .../providers/cpu/ml/treeregressor_test.cc | 6 ++-- 4 files changed, 22 insertions(+), 38 deletions(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h index dfd21b556438c..5916fa5fcc1c4 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h @@ -3,14 +3,10 @@ #pragma once -#include -#include #include #include #include -#include "gsl/gsl" - #include "core/common/inlined_containers.h" #include "core/common/common.h" #include "core/framework/op_kernel.h" @@ -26,28 +22,6 @@ inline bool _isnan_(double x) { return std::isnan(x); } inline bool _isnan_(int64_t) { return false; } inline bool _isnan_(int32_t) { return false; } -// Target count must be positive, and target ids must be in [0, target_count). -inline void ValidateTargetIds(gsl::span target_ids, - int64_t target_count, - std::string_view target_ids_name, - std::string_view target_count_name, - std::string_view target_count_display_name) { - ORT_ENFORCE(target_count > 0, - target_count_name, " must be positive, got ", target_count); - - if (target_ids.empty()) { - return; - } - - const auto [min_target_id, max_target_id] = std::minmax_element(target_ids.begin(), target_ids.end()); - ORT_ENFORCE(*min_target_id >= 0, - target_ids_name, " cannot have negative values (", *min_target_id, ")."); - ORT_ENFORCE(*max_target_id < target_count, - "At least one value (", *max_target_id, ") in ", target_ids_name, - " is greater or equal to ", target_count_display_name, " (", - target_count, ")."); -} - template struct TreeEnsembleAttributesV3 { TreeEnsembleAttributesV3() : n_targets_or_classes(0) {} @@ -98,8 +72,8 @@ struct TreeEnsembleAttributesV3 { target_class_weights = info.GetAttrsOrDefault("target_weights"); } - ValidateTargetIds(target_class_ids, n_targets_or_classes, - "target_ids or class_ids", "n_targets_or_classes", "the number of targets or classes"); + ORT_ENFORCE(n_targets_or_classes > 0, + "n_targets_or_classes must be positive, got ", n_targets_or_classes); ORT_ENFORCE(nodes_falsenodeids.size() == nodes_featureids.size(), "nodes_falsenodeids and nodes_featureids must have the same size, got ", nodes_falsenodeids.size(), " and ", nodes_featureids.size()); @@ -234,8 +208,6 @@ struct TreeEnsembleAttributesV5 { void convert_to_v3(TreeEnsembleAttributesV3& output) const { // Doing all transformations to get the old format. - ValidateTargetIds(leaf_targetids, n_targets, - "leaf_targetids", "n_targets", "the number of targets"); output.n_targets_or_classes = n_targets; output.aggregate_function = aggregateFunctionToString(); output.post_transform = postTransformToString(); diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 7ea9cc1edc478..af03a2d42ee18 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -173,6 +173,18 @@ Status TreeEnsembleCommon::Init( aggregate_function_ = MakeAggregateFunction(attributes.aggregate_function); post_transform_ = MakeTransform(attributes.post_transform); n_targets_or_classes_ = attributes.n_targets_or_classes; + ORT_ENFORCE(n_targets_or_classes_ > 0, + "target/class count must be positive, got ", n_targets_or_classes_); + if (!attributes.target_class_ids.empty()) { + const auto [min_target_id, max_target_id] = + std::minmax_element(attributes.target_class_ids.begin(), attributes.target_class_ids.end()); + ORT_ENFORCE(*min_target_id >= 0, + "target/class ids cannot have negative values (", *min_target_id, ")."); + ORT_ENFORCE(*max_target_id < n_targets_or_classes_, + "At least one value (", *max_target_id, + ") in target/class ids is greater or equal to the target/class count (", + n_targets_or_classes_, ")."); + } if (!attributes.base_values_as_tensor.empty()) { ORT_ENFORCE(attributes.base_values.empty()); base_values_ = attributes.base_values_as_tensor; @@ -331,7 +343,7 @@ Status TreeEnsembleCommon::Init( w.value = attributes.target_class_weights_as_tensor.empty() ? static_cast(attributes.target_class_weights[i]) : attributes.target_class_weights_as_tensor[i]; - // TreeEnsembleAttributesV3 already made sure that w.i >= 0 && w.i < n_targets_or_classes_. + // TreeEnsembleCommon::Init already made sure that w.i >= 0 && w.i < n_targets_or_classes_. if (leaf.truenode_or_weight.weight_data.n_weights == 0) { leaf.truenode_or_weight.weight_data.weight = static_cast(weights_.size()); leaf.value_or_unique_weight = w.value; diff --git a/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc b/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc index 6612050061854..922fab7cd2e43 100644 --- a/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc +++ b/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc @@ -336,7 +336,7 @@ TEST(MLOpTest, TreeEnsembleMinLeafTargetIdsOutsideBoundary) { 2, 2, {5}, - "At least one value (5) in leaf_targetids is greater or equal to the number of targets (2)"); + "At least one value (5) in target/class ids is greater or equal to the target/class count (2)"); } TEST(MLOpTest, TreeEnsembleMaxLeafTargetIdsOutsideBoundary) { @@ -344,7 +344,7 @@ TEST(MLOpTest, TreeEnsembleMaxLeafTargetIdsOutsideBoundary) { 3, 2, {5}, - "At least one value (5) in leaf_targetids is greater or equal to the number of targets (2)"); + "At least one value (5) in target/class ids is greater or equal to the target/class count (2)"); } TEST(MLOpTest, TreeEnsembleNegativeLeafTargetIds) { @@ -352,7 +352,7 @@ TEST(MLOpTest, TreeEnsembleNegativeLeafTargetIds) { 1, 2, {-1}, - "leaf_targetids cannot have negative values (-1)"); + "target/class ids cannot have negative values (-1)"); } TEST(MLOpTest, TreeEnsembleZeroTargets) { @@ -360,7 +360,7 @@ TEST(MLOpTest, TreeEnsembleZeroTargets) { 1, 0, {0}, - "n_targets must be positive, got 0"); + "target/class count must be positive, got 0"); } TEST(MLOpTest, TreeEnsembleLeafLike) { diff --git a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc index 7dbb40556a929..a7892da3d676f 100644 --- a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc +++ b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc @@ -924,7 +924,7 @@ TEST(MLOpTest, TreeRegressorNegativeTargetIds) { std::vector Y = {17.700000762939453, 11.100000381469727, -4.699999809265137}; test.AddInput("X", {3, 2}, X); test.AddOutput("Y", {3, 1}, Y); - test.Run(OpTester::ExpectResult::kExpectFailure, "target_ids or class_ids cannot have negative values"); + test.Run(OpTester::ExpectResult::kExpectFailure, "target/class ids cannot have negative values"); } TEST(MLOpTest, TreeRegressorOutsideBoundaryTargetIds) { @@ -966,7 +966,7 @@ TEST(MLOpTest, TreeRegressorOutsideBoundaryTargetIds) { std::vector Y = {17.700000762939453, 11.100000381469727, -4.699999809265137}; test.AddInput("X", {3, 2}, X); test.AddOutput("Y", {3, 1}, Y); - test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (1) in target_ids or class_ids is greater or equal to the number of targets or classes (1)"); + test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (1) in target/class ids is greater or equal to the target/class count (1)"); } TEST(MLOpTest, TreeEnsembleRegressorTargetIdsOutsideBoundary) { @@ -1002,7 +1002,7 @@ TEST(MLOpTest, TreeEnsembleRegressorTargetIdsOutsideBoundary) { test.AddInput("X", {1, 1}, X); test.AddOutput("Y", {1, 2}, {0.f, 0.f}); - test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (99) in target_ids or class_ids is greater or equal to the number of targets or classes (2)"); + test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (99) in target/class ids is greater or equal to the target/class count (2)"); } TEST(MLOpTest, TreeEnsembleRegressorNegativeFeatureId) {