Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,10 @@ class TreeAggregatorSum : public TreeAggregator<InputType, ThresholdType, Output
gsl::span<const SparseValue<ThresholdType>> weights) const {
auto it = weights.begin() + root.truenode_or_weight.weight_data.weight;
for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) {
ORT_ENFORCE(it->i < (int64_t)predictions.size());
predictions[onnxruntime::narrow<size_t>(it->i)].score += it->value;
predictions[onnxruntime::narrow<size_t>(it->i)].has_score = 1;
ORT_ENFORCE(it->i >= 0 && it->i < static_cast<int64_t>(predictions.size()));
const size_t target_id = onnxruntime::narrow<size_t>(it->i);
predictions[target_id].score += it->value;
predictions[target_id].has_score = 1;
}
}

Expand Down Expand Up @@ -393,11 +394,13 @@ class TreeAggregatorMin : public TreeAggregator<InputType, ThresholdType, Output
gsl::span<const SparseValue<ThresholdType>> weights) const {
auto it = weights.begin() + root.truenode_or_weight.weight_data.weight;
for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) {
predictions[onnxruntime::narrow<size_t>(it->i)].score =
(!predictions[onnxruntime::narrow<size_t>(it->i)].has_score || it->value < predictions[onnxruntime::narrow<size_t>(it->i)].score)
ORT_ENFORCE(it->i >= 0 && it->i < static_cast<int64_t>(predictions.size()));
const size_t target_id = onnxruntime::narrow<size_t>(it->i);
predictions[target_id].score =
(!predictions[target_id].has_score || it->value < predictions[target_id].score)
? it->value
: predictions[onnxruntime::narrow<size_t>(it->i)].score;
predictions[onnxruntime::narrow<size_t>(it->i)].has_score = 1;
: predictions[target_id].score;
predictions[target_id].has_score = 1;
}
}

Expand Down Expand Up @@ -449,11 +452,13 @@ class TreeAggregatorMax : public TreeAggregator<InputType, ThresholdType, Output
gsl::span<const SparseValue<ThresholdType>> weights) const {
auto it = weights.begin() + root.truenode_or_weight.weight_data.weight;
for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) {
predictions[onnxruntime::narrow<size_t>(it->i)].score =
(!predictions[onnxruntime::narrow<size_t>(it->i)].has_score || it->value > predictions[onnxruntime::narrow<size_t>(it->i)].score)
ORT_ENFORCE(it->i >= 0 && it->i < static_cast<int64_t>(predictions.size()));
const size_t target_id = onnxruntime::narrow<size_t>(it->i);
predictions[target_id].score =
(!predictions[target_id].has_score || it->value > predictions[target_id].score)
? it->value
: predictions[onnxruntime::narrow<size_t>(it->i)].score;
predictions[onnxruntime::narrow<size_t>(it->i)].has_score = 1;
: predictions[target_id].score;
predictions[target_id].has_score = 1;
}
}

Expand Down
14 changes: 4 additions & 10 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

#pragma once

#include <unordered_map>
#include <stack>
#include <vector>

#include "core/common/inlined_containers.h"
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "ml_common.h"
#include "tree_ensemble_helper.h"
#include <unordered_map>
#include <stack>
#include <vector>

namespace onnxruntime {
namespace ml {
Expand Down Expand Up @@ -138,13 +139,6 @@ struct TreeEnsembleAttributesV3 {
"base_values_as_tensor should have 0 or ", n_targets_or_classes, " values.");
}
}

int64_t min_ids = *std::min_element(target_class_ids.begin(), target_class_ids.end());
ORT_ENFORCE(min_ids >= 0, "target_ids or class_ids cannot have negative values (", min_ids, ").");
int64_t max_ids = *std::max_element(target_class_ids.begin(), target_class_ids.end());
ORT_ENFORCE(max_ids < n_targets_or_classes, "At least one value (", max_ids,
") in target_ids or class_ids is greater or equal to the number of targets or classes (",
n_targets_or_classes, ").");
}

std::string aggregate_function;
Expand Down
14 changes: 13 additions & 1 deletion onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,18 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
aggregate_function_ = MakeAggregateFunction(attributes.aggregate_function);
post_transform_ = MakeTransform(attributes.post_transform);
n_targets_or_classes_ = attributes.n_targets_or_classes;
ORT_ENFORCE(n_targets_or_classes_ > 0,
"target/class count must be positive, got ", n_targets_or_classes_);
if (!attributes.target_class_ids.empty()) {
const auto [min_target_id, max_target_id] =
std::minmax_element(attributes.target_class_ids.begin(), attributes.target_class_ids.end());
ORT_ENFORCE(*min_target_id >= 0,
"target/class ids cannot have negative values (", *min_target_id, ").");
ORT_ENFORCE(*max_target_id < n_targets_or_classes_,
"At least one value (", *max_target_id,
") in target/class ids is greater or equal to the target/class count (",
n_targets_or_classes_, ").");
}
if (!attributes.base_values_as_tensor.empty()) {
ORT_ENFORCE(attributes.base_values.empty());
base_values_ = attributes.base_values_as_tensor;
Expand Down Expand Up @@ -331,7 +343,7 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
w.value = attributes.target_class_weights_as_tensor.empty()
? static_cast<ThresholdType>(attributes.target_class_weights[i])
: attributes.target_class_weights_as_tensor[i];
// TreeEnsembleAttributesV3 already made sure that w.i >= 0 && w.i < n_targets_or_classes_.
// TreeEnsembleCommon::Init already made sure that w.i >= 0 && w.i < n_targets_or_classes_.
if (leaf.truenode_or_weight.weight_data.n_weights == 0) {
leaf.truenode_or_weight.weight_data.weight = static_cast<int32_t>(weights_.size());
leaf.value_or_unique_weight = w.value;
Expand Down
73 changes: 73 additions & 0 deletions onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,47 @@ void _multiply_arrays_values(std::vector<T>& data, int64_t val) {
}
}

static void RunInvalidLeafTargetIdsTest(int64_t aggregate_function,
int64_t n_targets,
std::vector<int64_t> leaf_targetids,
const std::string& expected_error) {
OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain);

const int64_t post_transform = 0;
std::vector<int64_t> tree_roots = {0};
std::vector<uint8_t> nodes_modes = {0};
std::vector<int64_t> nodes_featureids = {0};
std::vector<double> nodes_splits = {0.0};
std::vector<int64_t> nodes_truenodeids = {0};
std::vector<int64_t> nodes_trueleafs = {1};
std::vector<int64_t> nodes_falsenodeids = {0};
std::vector<int64_t> nodes_falseleafs = {1};
std::vector<double> leaf_weights = {1.0};

auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes");
auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits");
auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight");

test.AddAttribute("n_targets", n_targets);
test.AddAttribute("aggregate_function", aggregate_function);
test.AddAttribute("post_transform", post_transform);
test.AddAttribute("tree_roots", tree_roots);
test.AddAttribute("nodes_modes", nodes_modes_as_tensor);
test.AddAttribute("nodes_featureids", nodes_featureids);
test.AddAttribute("nodes_splits", nodes_splits_as_tensor);
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
test.AddAttribute("nodes_trueleafs", nodes_trueleafs);
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
test.AddAttribute("nodes_falseleafs", nodes_falseleafs);
test.AddAttribute("leaf_targetids", leaf_targetids);
test.AddAttribute("leaf_weights", leaf_weights_as_tensor);

const int64_t output_targets = n_targets > 0 ? n_targets : 0;
test.AddInput<double>("X", {1, 1}, {1.0});
test.AddOutput<double>("Y", {1, output_targets}, std::vector<double>(static_cast<size_t>(output_targets), 0.0));
test.Run(OpTester::ExpectResult::kExpectFailure, expected_error);
}

template <typename T>
void GenTreeAndRunTest(const std::vector<T>& X, const std::vector<T>& Y, const int64_t& aggregate_function, int n_trees = 1) {
OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain);
Expand Down Expand Up @@ -290,6 +331,38 @@ TEST(MLOpTest, TreeEnsembleLeafOnly) {
test.Run();
}

TEST(MLOpTest, TreeEnsembleMinLeafTargetIdsOutsideBoundary) {
RunInvalidLeafTargetIdsTest(
2,
2,
{5},
"At least one value (5) in target/class ids is greater or equal to the target/class count (2)");
}

TEST(MLOpTest, TreeEnsembleMaxLeafTargetIdsOutsideBoundary) {
RunInvalidLeafTargetIdsTest(
3,
2,
{5},
"At least one value (5) in target/class ids is greater or equal to the target/class count (2)");
}

TEST(MLOpTest, TreeEnsembleNegativeLeafTargetIds) {
RunInvalidLeafTargetIdsTest(
1,
2,
{-1},
"target/class ids cannot have negative values (-1)");
}

TEST(MLOpTest, TreeEnsembleZeroTargets) {
RunInvalidLeafTargetIdsTest(
1,
0,
{0},
"target/class count must be positive, got 0");
}

TEST(MLOpTest, TreeEnsembleLeafLike) {
OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain);
int64_t n_targets = 1;
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/test/providers/cpu/ml/treeregressor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ TEST(MLOpTest, TreeRegressorNegativeTargetIds) {
std::vector<float> Y = {17.700000762939453, 11.100000381469727, -4.699999809265137};
test.AddInput<float>("X", {3, 2}, X);
test.AddOutput<float>("Y", {3, 1}, Y);
test.Run(OpTester::ExpectResult::kExpectFailure, "target_ids or class_ids cannot have negative values");
test.Run(OpTester::ExpectResult::kExpectFailure, "target/class ids cannot have negative values");
}

TEST(MLOpTest, TreeRegressorOutsideBoundaryTargetIds) {
Expand Down Expand Up @@ -966,7 +966,7 @@ TEST(MLOpTest, TreeRegressorOutsideBoundaryTargetIds) {
std::vector<float> Y = {17.700000762939453, 11.100000381469727, -4.699999809265137};
test.AddInput<float>("X", {3, 2}, X);
test.AddOutput<float>("Y", {3, 1}, Y);
test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (1) in target_ids or class_ids is greater or equal to the number of targets or classes (1)");
test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (1) in target/class ids is greater or equal to the target/class count (1)");
}

TEST(MLOpTest, TreeEnsembleRegressorTargetIdsOutsideBoundary) {
Expand Down Expand Up @@ -1002,7 +1002,7 @@ TEST(MLOpTest, TreeEnsembleRegressorTargetIdsOutsideBoundary) {
test.AddInput<float>("X", {1, 1}, X);
test.AddOutput<float>("Y", {1, 2}, {0.f, 0.f});

test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (99) in target_ids or class_ids is greater or equal to the number of targets or classes (2)");
test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (99) in target/class ids is greater or equal to the target/class count (2)");
}

TEST(MLOpTest, TreeEnsembleRegressorNegativeFeatureId) {
Expand Down
Loading