diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h index ebdac882bff09..2878147c815b5 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h @@ -290,9 +290,10 @@ class TreeAggregatorSum : public TreeAggregator> weights) const { auto it = weights.begin() + root.truenode_or_weight.weight_data.weight; for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) { - ORT_ENFORCE(it->i < (int64_t)predictions.size()); - predictions[onnxruntime::narrow(it->i)].score += it->value; - predictions[onnxruntime::narrow(it->i)].has_score = 1; + ORT_ENFORCE(it->i >= 0 && it->i < static_cast(predictions.size())); + const size_t target_id = onnxruntime::narrow(it->i); + predictions[target_id].score += it->value; + predictions[target_id].has_score = 1; } } @@ -393,11 +394,13 @@ class TreeAggregatorMin : public TreeAggregator> weights) const { auto it = weights.begin() + root.truenode_or_weight.weight_data.weight; for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) { - predictions[onnxruntime::narrow(it->i)].score = - (!predictions[onnxruntime::narrow(it->i)].has_score || it->value < predictions[onnxruntime::narrow(it->i)].score) + ORT_ENFORCE(it->i >= 0 && it->i < static_cast(predictions.size())); + const size_t target_id = onnxruntime::narrow(it->i); + predictions[target_id].score = + (!predictions[target_id].has_score || it->value < predictions[target_id].score) ? it->value - : predictions[onnxruntime::narrow(it->i)].score; - predictions[onnxruntime::narrow(it->i)].has_score = 1; + : predictions[target_id].score; + predictions[target_id].has_score = 1; } } @@ -449,11 +452,13 @@ class TreeAggregatorMax : public TreeAggregator> weights) const { auto it = weights.begin() + root.truenode_or_weight.weight_data.weight; for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) { - predictions[onnxruntime::narrow(it->i)].score = - (!predictions[onnxruntime::narrow(it->i)].has_score || it->value > predictions[onnxruntime::narrow(it->i)].score) + ORT_ENFORCE(it->i >= 0 && it->i < static_cast(predictions.size())); + const size_t target_id = onnxruntime::narrow(it->i); + predictions[target_id].score = + (!predictions[target_id].has_score || it->value > predictions[target_id].score) ? it->value - : predictions[onnxruntime::narrow(it->i)].score; - predictions[onnxruntime::narrow(it->i)].has_score = 1; + : predictions[target_id].score; + predictions[target_id].has_score = 1; } } diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h index 1c89cc0b4723d..5916fa5fcc1c4 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h @@ -3,14 +3,15 @@ #pragma once +#include +#include +#include + #include "core/common/inlined_containers.h" #include "core/common/common.h" #include "core/framework/op_kernel.h" #include "ml_common.h" #include "tree_ensemble_helper.h" -#include -#include -#include namespace onnxruntime { namespace ml { @@ -138,13 +139,6 @@ struct TreeEnsembleAttributesV3 { "base_values_as_tensor should have 0 or ", n_targets_or_classes, " values."); } } - - int64_t min_ids = *std::min_element(target_class_ids.begin(), target_class_ids.end()); - ORT_ENFORCE(min_ids >= 0, "target_ids or class_ids cannot have negative values (", min_ids, ")."); - int64_t max_ids = *std::max_element(target_class_ids.begin(), target_class_ids.end()); - ORT_ENFORCE(max_ids < n_targets_or_classes, "At least one value (", max_ids, - ") in target_ids or class_ids is greater or equal to the number of targets or classes (", - n_targets_or_classes, ")."); } std::string aggregate_function; diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 7ea9cc1edc478..af03a2d42ee18 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -173,6 +173,18 @@ Status TreeEnsembleCommon::Init( aggregate_function_ = MakeAggregateFunction(attributes.aggregate_function); post_transform_ = MakeTransform(attributes.post_transform); n_targets_or_classes_ = attributes.n_targets_or_classes; + ORT_ENFORCE(n_targets_or_classes_ > 0, + "target/class count must be positive, got ", n_targets_or_classes_); + if (!attributes.target_class_ids.empty()) { + const auto [min_target_id, max_target_id] = + std::minmax_element(attributes.target_class_ids.begin(), attributes.target_class_ids.end()); + ORT_ENFORCE(*min_target_id >= 0, + "target/class ids cannot have negative values (", *min_target_id, ")."); + ORT_ENFORCE(*max_target_id < n_targets_or_classes_, + "At least one value (", *max_target_id, + ") in target/class ids is greater or equal to the target/class count (", + n_targets_or_classes_, ")."); + } if (!attributes.base_values_as_tensor.empty()) { ORT_ENFORCE(attributes.base_values.empty()); base_values_ = attributes.base_values_as_tensor; @@ -331,7 +343,7 @@ Status TreeEnsembleCommon::Init( w.value = attributes.target_class_weights_as_tensor.empty() ? static_cast(attributes.target_class_weights[i]) : attributes.target_class_weights_as_tensor[i]; - // TreeEnsembleAttributesV3 already made sure that w.i >= 0 && w.i < n_targets_or_classes_. + // TreeEnsembleCommon::Init already made sure that w.i >= 0 && w.i < n_targets_or_classes_. if (leaf.truenode_or_weight.weight_data.n_weights == 0) { leaf.truenode_or_weight.weight_data.weight = static_cast(weights_.size()); leaf.value_or_unique_weight = w.value; diff --git a/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc b/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc index 1510f3fe3e012..922fab7cd2e43 100644 --- a/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc +++ b/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc @@ -93,6 +93,47 @@ void _multiply_arrays_values(std::vector& data, int64_t val) { } } +static void RunInvalidLeafTargetIdsTest(int64_t aggregate_function, + int64_t n_targets, + std::vector leaf_targetids, + const std::string& expected_error) { + OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain); + + const int64_t post_transform = 0; + std::vector tree_roots = {0}; + std::vector nodes_modes = {0}; + std::vector nodes_featureids = {0}; + std::vector nodes_splits = {0.0}; + std::vector nodes_truenodeids = {0}; + std::vector nodes_trueleafs = {1}; + std::vector nodes_falsenodeids = {0}; + std::vector nodes_falseleafs = {1}; + std::vector leaf_weights = {1.0}; + + auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes"); + auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits"); + auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight"); + + test.AddAttribute("n_targets", n_targets); + test.AddAttribute("aggregate_function", aggregate_function); + test.AddAttribute("post_transform", post_transform); + test.AddAttribute("tree_roots", tree_roots); + test.AddAttribute("nodes_modes", nodes_modes_as_tensor); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_splits", nodes_splits_as_tensor); + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_trueleafs", nodes_trueleafs); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_falseleafs", nodes_falseleafs); + test.AddAttribute("leaf_targetids", leaf_targetids); + test.AddAttribute("leaf_weights", leaf_weights_as_tensor); + + const int64_t output_targets = n_targets > 0 ? n_targets : 0; + test.AddInput("X", {1, 1}, {1.0}); + test.AddOutput("Y", {1, output_targets}, std::vector(static_cast(output_targets), 0.0)); + test.Run(OpTester::ExpectResult::kExpectFailure, expected_error); +} + template void GenTreeAndRunTest(const std::vector& X, const std::vector& Y, const int64_t& aggregate_function, int n_trees = 1) { OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain); @@ -290,6 +331,38 @@ TEST(MLOpTest, TreeEnsembleLeafOnly) { test.Run(); } +TEST(MLOpTest, TreeEnsembleMinLeafTargetIdsOutsideBoundary) { + RunInvalidLeafTargetIdsTest( + 2, + 2, + {5}, + "At least one value (5) in target/class ids is greater or equal to the target/class count (2)"); +} + +TEST(MLOpTest, TreeEnsembleMaxLeafTargetIdsOutsideBoundary) { + RunInvalidLeafTargetIdsTest( + 3, + 2, + {5}, + "At least one value (5) in target/class ids is greater or equal to the target/class count (2)"); +} + +TEST(MLOpTest, TreeEnsembleNegativeLeafTargetIds) { + RunInvalidLeafTargetIdsTest( + 1, + 2, + {-1}, + "target/class ids cannot have negative values (-1)"); +} + +TEST(MLOpTest, TreeEnsembleZeroTargets) { + RunInvalidLeafTargetIdsTest( + 1, + 0, + {0}, + "target/class count must be positive, got 0"); +} + TEST(MLOpTest, TreeEnsembleLeafLike) { OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain); int64_t n_targets = 1; diff --git a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc index 7dbb40556a929..a7892da3d676f 100644 --- a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc +++ b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc @@ -924,7 +924,7 @@ TEST(MLOpTest, TreeRegressorNegativeTargetIds) { std::vector Y = {17.700000762939453, 11.100000381469727, -4.699999809265137}; test.AddInput("X", {3, 2}, X); test.AddOutput("Y", {3, 1}, Y); - test.Run(OpTester::ExpectResult::kExpectFailure, "target_ids or class_ids cannot have negative values"); + test.Run(OpTester::ExpectResult::kExpectFailure, "target/class ids cannot have negative values"); } TEST(MLOpTest, TreeRegressorOutsideBoundaryTargetIds) { @@ -966,7 +966,7 @@ TEST(MLOpTest, TreeRegressorOutsideBoundaryTargetIds) { std::vector Y = {17.700000762939453, 11.100000381469727, -4.699999809265137}; test.AddInput("X", {3, 2}, X); test.AddOutput("Y", {3, 1}, Y); - test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (1) in target_ids or class_ids is greater or equal to the number of targets or classes (1)"); + test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (1) in target/class ids is greater or equal to the target/class count (1)"); } TEST(MLOpTest, TreeEnsembleRegressorTargetIdsOutsideBoundary) { @@ -1002,7 +1002,7 @@ TEST(MLOpTest, TreeEnsembleRegressorTargetIdsOutsideBoundary) { test.AddInput("X", {1, 1}, X); test.AddOutput("Y", {1, 2}, {0.f, 0.f}); - test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (99) in target_ids or class_ids is greater or equal to the number of targets or classes (2)"); + test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (99) in target/class ids is greater or equal to the target/class count (2)"); } TEST(MLOpTest, TreeEnsembleRegressorNegativeFeatureId) {