From 977a13d4180acdb62af7d370ff3c75db74430e4d Mon Sep 17 00:00:00 2001 From: Harsh Chauhan Date: Tue, 16 Jun 2026 01:27:05 +0530 Subject: [PATCH 1/2] extended gpu support for selu --- core/inc/SOFIE/ROperator.hxx | 3 ++- core/inc/SOFIE/ROperator_Selu.hxx | 40 +++++++++++++++++++++++++++++++ core/src/RModel_ALPAKA.cxx | 3 ++- 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/core/inc/SOFIE/ROperator.hxx b/core/inc/SOFIE/ROperator.hxx index c9ce5cd..67f386b 100644 --- a/core/inc/SOFIE/ROperator.hxx +++ b/core/inc/SOFIE/ROperator.hxx @@ -38,7 +38,8 @@ enum class OperatorKind { UNARY_COS=22, UNARY_ABS=23, CLIP=24, - NOT=25 + NOT=25, + SELU=26 }; inline const char* toString(OperatorKind kind) { diff --git a/core/inc/SOFIE/ROperator_Selu.hxx b/core/inc/SOFIE/ROperator_Selu.hxx index 5bec42c..3db298d 100644 --- a/core/inc/SOFIE/ROperator_Selu.hxx +++ b/core/inc/SOFIE/ROperator_Selu.hxx @@ -25,6 +25,7 @@ public: fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){ fInputTensorNames = { fNX }; fOutputTensorNames = { fNY }; + fKind = OperatorKind::SELU; } std::vector TypeInference(std::vector input) override { @@ -59,6 +60,45 @@ public: } std::vector GetStdLibs() override { return { std::string("cmath") };} + + std::string Generate_GPU_Kernel_ALPAKA(std::string /*opName*/) override { + std::string op; + op = "\n//---- SELU_KERNEL_ALPAKA//\n"; + op += "struct SeluKernel {\n"; + op += SP + "template\n"; + op += SP + "ALPAKA_FN_ACC void operator()(TAcc const& acc, T const* __restrict__ data, T* __restrict__ out, std::size_t numElements) const {\n"; + op += SP + SP + "const auto idx = alpaka::getIdx(acc)[0];\n"; + op += SP + SP + "if (idx < numElements) {\n"; + op += SP + SP + SP + "T x = data[idx];\n"; + op += SP + SP + SP + "T inner = T(1.6732632423543772848170429916717) * (exp(x) - T(1));\n"; + op += SP + SP + SP + "out[idx] = T(1.0507009873554804934193349852946) * ((x > T(0) ? x : T(0)) + (inner < T(0) ? inner : T(0)));\n"; + op += SP + SP + "}\n"; + op += SP + "}\n"; + op += "};\n"; + return op; + } + + std::string Generate_GPU_Kernel_Definitions_ALPAKA(std::string /*opName*/) override { + return SP + "SeluKernel seluKernel;\n"; + } + + std::string Generate_GPU_ALPAKA(std::string OpName) override { + OpName = "op_" + OpName; + if (fShape.empty()) { + throw std::runtime_error("SOFIE Selu called to Generate_GPU_ALPAKA without being initialized"); + } + std::stringstream out; + std::string length = ConvertDimShapeToLength(fShape); + out << "\n//------ SELU_GPU_ALPAKA\n"; + out << SP << "auto const elementsPerThread_" << fNX << " = Vec::all(static_cast(1));\n"; + out << SP << "auto const elementsPerGrid_" << fNX << " = Vec::all(Idx{" << length << "});\n"; + out << SP << "auto const workDiv_" << fNX << " = sofie_workdiv(elementsPerGrid_" << fNX << ");\n"; + out << SP << "auto task_" << OpName << " = alpaka::createTaskKernel(workDiv_" << fNX + << ", seluKernel, alpaka::getPtrNative(deviceBuf_" << fNX + << "), alpaka::getPtrNative(deviceBuf_" << fNY << "), static_cast(" << length << "));\n"; + out << SP << "alpaka::enqueue(queue, task_" << OpName << ");\n"; + return out.str(); + } }; }//SOFIE diff --git a/core/src/RModel_ALPAKA.cxx b/core/src/RModel_ALPAKA.cxx index 50c9913..bf7161d 100644 --- a/core/src/RModel_ALPAKA.cxx +++ b/core/src/RModel_ALPAKA.cxx @@ -545,7 +545,8 @@ void RModel::GenerateSessionCode_GPU_ALPAKA() { SOFIE::OperatorKind::UNARY_SIN, SOFIE::OperatorKind::UNARY_COS, SOFIE::OperatorKind::UNARY_ABS, - SOFIE::OperatorKind::NOT + SOFIE::OperatorKind::NOT, + SOFIE::OperatorKind::SELU }; bool OpNeedsBlas = false; From 278e8857348ba6050f9780aeed7aad7021c1e27d Mon Sep 17 00:00:00 2001 From: Harsh Chauhan Date: Tue, 16 Jun 2026 03:03:11 +0530 Subject: [PATCH 2/2] add gtest for selu gpu support --- .../TestCustomModelsFromONNXForAlpakaCuda.cxx | 37 ++++++++++++++++++ core/test/input_models/Selu.onnx | Bin 0 -> 90 bytes .../test/input_models/references/Selu.ref.hxx | 5 +++ 3 files changed, 42 insertions(+) create mode 100644 core/test/input_models/Selu.onnx create mode 100644 core/test/input_models/references/Selu.ref.hxx diff --git a/core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx b/core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx index 5ad9383..01b755a 100644 --- a/core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx +++ b/core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx @@ -176,6 +176,9 @@ #include "Clip_FromONNX_GPU_ALPAKA.hxx" #include "Not_FromONNX_GPU_ALPAKA.hxx" +#include "Selu_FromONNX_GPU_ALPAKA.hxx" +#include "input_models/references/Selu.ref.hxx" + #include "GNN_model_FromONNX_GPU_ALPAKA.hxx" #include @@ -3161,3 +3164,37 @@ TEST_F(SofieAlpakaTest, Logic_BitwiseNot) for (std::size_t i = 0; i < N; ++i) EXPECT_EQ(res[i], ref[i]) << " index=" << i; } + +TEST_F(SofieAlpakaTest, Selu) +{ + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + + // input spans negative + positive so the SELU negative branch is exercised + std::vector input({1.0f, -2.0f, 3.0f, 0.5f, -1.0f, 2.0f}); + + auto input_h = alpaka::allocBuf(host, Ext1D::all(Idx{input.size()})); + float* input_ptr = reinterpret_cast(alpaka::getPtrNative(input_h)); + for (Idx i = 0; i < input.size(); ++i) input_ptr[i] = input[i]; + + auto input_d = alpaka::allocBuf(device, Ext1D::all(Idx{input.size()})); + alpaka::memcpy(queue, input_d, input_h); + alpaka::wait(queue); + + constexpr size_t nOut = sizeof(Selu_ExpectedOutput::outputs) / sizeof(float); + auto result_h = alpaka::allocBuf(host, Ext1D::all(Idx{nOut})); + + { + SOFIE_Selu::Session session; + auto result = session.infer(input_d); + alpaka::wait(queue); + cudaDeviceSynchronize(); + alpaka::memcpy(queue, result_h, result); + alpaka::wait(queue); + } + + float* res_ptr = reinterpret_cast(alpaka::getPtrNative(result_h)); + float* correct = Selu_ExpectedOutput::outputs; + for (size_t i = 0; i < nOut; ++i) { + EXPECT_LE(std::abs(res_ptr[i] - correct[i]), TOLERANCE) << "i=" << i; + } +} diff --git a/core/test/input_models/Selu.onnx b/core/test/input_models/Selu.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ce999694f1c939033beef5471c0eb0a0dda71551 GIT binary patch literal 90 zcmd