Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion core/inc/SOFIE/ROperator_Conv.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -836,11 +836,11 @@ public:
size_t gemm_n = outChannels; // output channels
size_t gemm_k = fShapeW[1] * kernelSize; // input channels/group * kernel volume
size_t gemm_m = oDepth * oHeight * oWidth; // output spatial size per channel
if (fAttrGroup > 1) gemm_n /= fAttrGroup; // per-group output channels for grouped conv
size_t colElements = gemm_k * gemm_m; // colRows * colCols
size_t wTotal = ConvertShapeToLength(fShapeW);

// For group conv: per-group output channels and _f offset
// gemm_n stays as total output channels — we divide per group at launch
size_t groupFOffset = gemm_n * gemm_k; // elements of _f per group

std::stringstream out;
Expand Down Expand Up @@ -986,6 +986,7 @@ public:
size_t gemm_n_ = fShapeW[0];
size_t gemm_k_ = fShapeW[1] * kSize_;
size_t gemm_m_ = oDepth_ * oHeight_ * oWidth_;
if (fAttrGroup > 1) gemm_n_ /= fAttrGroup;
auto lda = std::to_string(gemm_m_); // ld for xcol^T (gemm_m×gemm_k col-major)
auto ldb = std::to_string(gemm_k_); // ld for xf^T (gemm_k×gemm_n col-major)
auto ldc = std::to_string(gemm_m_); // ld for y^T (gemm_m×gemm_n col-major)
Expand Down
37 changes: 37 additions & 0 deletions core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@
#include "ConvWithAsymmetricPadding_FromONNX_GPU_ALPAKA.hxx"
#include "input_models/references/ConvWithAsymmetricPadding.ref.hxx"

#include "ConvGroupBatch_FromONNX_GPU_ALPAKA.hxx"
#include "input_models/references/ConvGroupBatch.ref.hxx"
#include "input_models/references/ConvGroupBatch_input.ref.hxx"

#include "BatchNorm_FromONNX_GPU_ALPAKA.hxx"
#include "BatchNormRelu_FromONNX_GPU_ALPAKA.hxx"

Expand Down Expand Up @@ -2200,6 +2204,39 @@ TEST_F(SofieAlpakaTest, ConvWithAsymmetricPadding)
}
}

// group=2 and batch=4 together; patch test for the gemm_n /= fAttrGroup fix
TEST_F(SofieAlpakaTest, ConvGroupBatch)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;

constexpr size_t N = sizeof(ConvGroupBatch_Input::data) / sizeof(float);
auto input_h = alpaka::allocBuf<float, Idx>(host, Ext1D::all(Idx{N}));
float* input_ptr = reinterpret_cast<float*>(alpaka::getPtrNative(input_h));
for (Idx i = 0; i < N; ++i) input_ptr[i] = ConvGroupBatch_Input::data[i];

auto input_d = alpaka::allocBuf<float, Idx>(device, Ext1D::all(Idx{N}));
alpaka::memcpy(queue, input_d, input_h);
alpaka::wait(queue);

constexpr size_t nOut = sizeof(ConvGroupBatch_ExpectedOutput::output) / sizeof(float);
auto result_h = alpaka::allocBuf<float, Idx>(host, Ext1D::all(Idx{nOut}));

{
SOFIE_ConvGroupBatch::Session<alpaka::TagGpuCudaRt> session("ConvGroupBatch_FromONNX_GPU_ALPAKA.dat");
auto result = session.infer(input_d);
alpaka::wait(queue);
cudaDeviceSynchronize();
alpaka::memcpy(queue, result_h, result);
alpaka::wait(queue);
}

float* res_ptr = reinterpret_cast<float*>(alpaka::getPtrNative(result_h));
float* correct = ConvGroupBatch_ExpectedOutput::output;
for (size_t i = 0; i < nOut; ++i) {
EXPECT_LE(std::abs(res_ptr[i] - correct[i]), TOLERANCE) << "i=" << i;
}
}

TEST_F(SofieAlpakaTest, BatchNormalization)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;
Expand Down
Binary file added core/test/input_models/ConvGroupBatch.onnx
Binary file not shown.
5 changes: 5 additions & 0 deletions core/test/input_models/references/ConvGroupBatch.ref.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// Auto-generated by ConvGroupBatchModelGenerator.py - DO NOT EDIT
#pragma once
namespace ConvGroupBatch_ExpectedOutput {
float output[] = {-0.36076069f, 0.33097583f, -0.28844666f, -1.04436600f, 0.49192899f, -0.12130111f, 0.39897686f, -1.49153733f, -0.76932675f, 0.48658341f, -0.71374869f, 0.40121353f, -1.12167597f, -0.62728310f, 0.35699576f, -0.47379202f, 0.01621294f, 0.17111908f, -0.74641943f, 0.11123379f, -0.39323545f, 0.49183100f, 0.45035821f, -0.15920340f, -0.45440555f, -0.42033449f, 1.59069741f, 1.25395763f, -0.02720551f, 1.12741852f, 0.57816750f, -0.84116793f, 0.71935523f, 0.19200583f, -0.77897775f, 0.69070578f, 0.69550228f, 1.33932674f, 0.79535830f, -1.28075194f, -0.65100706f, 0.25080451f, -0.68134499f, -1.09874177f, -0.86597049f, 0.00925574f, 1.13835633f, -1.06737292f, -0.24952698f, -0.28619090f, 0.29835856f, -0.21624319f, 0.15673074f, 0.41889524f, 0.55454159f, 0.61738980f, 0.88214344f, 0.09535512f, 0.24054861f, 0.51925898f, 0.24246022f, -0.68892086f, 0.09302564f, -0.42311478f, -0.22747509f, 0.13375665f, -0.33798844f, 0.15409425f, 0.62915587f, -0.97260368f, -0.03625952f, 0.69901109f, -0.23517157f, 0.67972404f, 0.21827847f, 0.11843839f, -0.01277096f, -0.32323205f, -0.03658739f, -0.31435120f, -0.43483058f, -0.27148122f, 0.77499795f, -0.81867105f, -0.99898142f, -0.67271334f, 1.01101458f, 0.78309727f, 0.88418299f, 0.14850061f, -0.46363044f, -0.06768467f, 0.44989449f, 0.52169687f, 0.44447100f, -0.94505298f, -0.56453151f, -0.72638625f, 0.33930093f, -0.03495500f, -0.02821438f, -1.03944421f, -0.78911626f, 0.30141425f, -0.71277291f, 0.87587613f, 0.76956397f, -0.30962253f, 0.66245174f, -1.33300364f, 0.17070784f, -0.87962157f, -1.13295639f, 0.44555551f, -1.04297531f, 0.75829023f, -0.03878885f, -1.10133481f, 0.55174708f, -0.75599957f, -0.51625776f, -0.05342628f, -0.36763048f, 0.65811008f, -0.70035321f, -1.43289232f, -0.17264232f, 0.24939704f, 0.22910510f, -0.15060824f, 1.66544592f, 0.13985462f, -0.15302825f, 1.13994932f, 0.22160804f, 0.24457796f, 0.47076693f, 0.20962349f, -0.05407842f, 1.70852649f, 1.12710464f, -1.44766259f, 1.25966561f, 0.29169840f, 0.92405045f, -0.75349152f, 1.58519661f, 0.49950320f, 0.67342019f, -0.26927525f, -0.67771661f, 0.40583348f, -0.57214516f, -0.60764003f, 0.25248834f, -0.43436378f, 0.62197685f, 0.76375240f, -1.36287653f, 0.84372342f, -0.35162437f, -1.54409766f, 0.51240182f, 0.20225400f, -0.02268404f, 0.30259740f, -0.04716773f, 0.67471498f, 0.71413249f, -0.46658528f, -0.55675507f, -0.51797366f, 0.41532481f, 1.02307224f, 0.39207584f, 0.08235835f, 0.13141029f, 1.01161790f, 0.15537255f, -0.23259896f, -0.64753038f, -0.84689122f, 0.32039008f, 0.86086088f, -0.10391510f, -0.43225932f, -0.34928828f, 0.56980270f, 0.50981444f, -0.21508121f, 0.16830127f, -1.03776038f, -0.56918675f, 0.11749696f, 0.10682096f, 0.31032404f, -0.09735793f, -0.31645539f, -0.08064906f, -0.20384440f, -0.01380761f, -0.07755496f, -0.60068375f, -0.40580261f, 0.91849798f, 0.08390346f, -0.52695745f, -0.42855084f, -1.03313851f, 0.33994401f, -0.36245525f, -0.69688886f, 0.78342688f, -0.97037071f, -0.48480308f, -0.27335811f, -0.97596729f, -0.23528086f, -0.21981074f, -0.63335001f, -0.88829535f, -0.29944593f, -0.19596440f, 1.19875622f, -0.49918032f, -0.79209542f, 0.10065345f, 1.09010482f, 0.01321086f, 2.13910604f, 1.11112499f, 0.72985148f, 1.00838768f, 0.54585910f, -0.15993178f, 0.26541781f, 0.85031760f, 1.73971391f, 1.52358532f, 0.45130971f, 2.48678756f, 0.19343181f, 0.12072347f, -0.44385901f, 0.52359331f, -0.61506343f, 1.24399531f, -1.18155575f, 0.42505923f, -1.23762119f, -0.31627494f, 0.56629950f, -0.52162617f, 1.12653041f, -0.74352539f, 0.52006209f, -0.15141015f, -0.73136127f, 2.61348605f, 1.23570156f, -0.31728053f, 0.77408886f, 0.27774701f, -0.56392157f, 0.45240438f, 0.84599215f, -0.27067959f, 2.16201377f, -0.71181077f, -0.21107928f, 0.28580725f, -0.41565710f, 0.90941215f, -0.52555859f, -0.16738491f, -0.54808658f, -1.08526981f, 0.06592142f, -1.71680820f, 0.08483648f, 0.01231547f, 0.08641901f, -0.34641835f, -1.74067116f, -0.29078999f, -0.16861805f, 0.45419741f, 1.03918445f, 0.26395023f, -0.40493611f, -0.22315894f, 0.71085817f, -0.88118380f, 1.08953798f, 0.25373179f, 0.14846370f, -0.49211681f, 0.22963192f, 0.22093298f, 0.33117050f, -0.37713605f, 0.35486460f, -0.36020797f, -0.33668989f, 0.11798792f, -0.43228078f, -0.57194173f, 0.20756347f, -0.65037400f, -0.15730046f, -0.13377419f, -0.50105321f, -0.28189951f, -0.22371638f, -0.74900097f, -0.02032955f, -0.61617506f, -1.08462250f, -0.56890285f, -0.93203634f, -0.29669374f, -1.03467071f, 0.21930207f, -0.11985322f, -0.08238113f, 0.27048582f, -0.00008567f, -0.82777119f, -0.17688563f, 0.15913069f, -0.80205917f, 0.10852872f, 1.17724657f, 0.41065979f, 0.81592619f, 0.93737853f, 0.34568185f, 1.63516307f, -0.16411117f, 0.52121985f, 1.25948155f, 0.26429343f, -0.58898950f, 1.38607669f, 0.54334569f, 0.25247136f, 0.33251810f, 2.30225015f, 0.15690750f, 1.08385587f, 0.45817482f, 0.39435270f, 0.23890808f, -0.37897390f, 0.04152094f, -0.93999600f, -0.31861931f, -0.26991367f, 1.17474163f, 0.61628640f, 0.66774678f, -0.05754669f, -0.06096449f, -0.57611501f, 0.53568387f, -0.61293066f, -0.34657979f, -0.32216960f, -0.81506526f, -1.00963306f, 0.16005160f, 0.55423015f, -0.18409301f, 0.05382362f, 0.64573503f, -0.30393291f, 0.31414580f, -0.11099648f, 0.88745099f, 0.38012391f, -0.45172656f, -0.73033381f, -0.79347444f, -0.27273351f, 0.03713746f, -0.46960682f, 0.30829662f, -0.86766714f, -0.20282577f, -0.30651718f, 0.38887522f, -0.91033989f, -0.40621519f, -0.28930432f, 0.68293715f, 0.02102731f, -0.87800777f, -0.56976014f, -1.07986236f, -0.12437577f};
} // namespace ConvGroupBatch_ExpectedOutput
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// Auto-generated by ConvGroupBatchModelGenerator.py - DO NOT EDIT
#pragma once
namespace ConvGroupBatch_Input {
static float data[400] = {0.49671414f, -0.13826430f, 0.64768857f, 1.52302980f, -0.23415338f, -0.23413695f, 1.57921278f, 0.76743472f, -0.46947438f, 0.54256004f, -0.46341768f, -0.46572974f, 0.24196227f, -1.91328025f, -1.72491789f, -0.56228751f, -1.01283109f, 0.31424734f, -0.90802407f, -1.41230369f, 1.46564877f, -0.22577630f, 0.06752820f, -1.42474818f, -0.54438275f, 0.11092259f, -1.15099359f, 0.37569803f, -0.60063869f, -0.29169375f, -0.60170662f, 1.85227823f, -0.01349723f, -1.05771089f, 0.82254493f, -1.22084367f, 0.20886360f, -1.95967007f, -1.32818604f, 0.19686124f, 0.73846656f, 0.17136829f, -0.11564828f, -0.30110368f, -1.47852194f, -0.71984422f, -0.46063876f, 1.05712223f, 0.34361830f, -1.76304018f, 0.32408398f, -0.38508227f, -0.67692202f, 0.61167628f, 1.03099954f, 0.93128014f, -0.83921754f, -0.30921239f, 0.33126342f, 0.97554511f, -0.47917423f, -0.18565898f, -1.10633492f, -1.19620657f, 0.81252581f, 1.35624003f, -0.07201012f, 1.00353289f, 0.36163601f, -0.64511973f, 0.36139560f, 1.53803658f, -0.03582604f, 1.56464362f, -2.61974502f, 0.82190251f, 0.08704707f, -0.29900736f, 0.09176078f, -1.98756886f, -0.21967189f, 0.35711256f, 1.47789407f, -0.51827019f, -0.80849361f, -0.50175703f, 0.91540211f, 0.32875112f, -0.52976018f, 0.51326746f, 0.09707755f, 0.96864498f, -0.70205307f, -0.32766214f, -0.39210814f, -1.46351492f, 0.29612029f, 0.26105526f, 0.00511346f, -0.23458713f, -1.41537070f, -0.42064533f, -0.34271452f, -0.80227727f, -0.16128571f, 0.40405086f, 1.88618588f, 0.17457782f, 0.25755039f, -0.07444592f, -1.91877127f, -0.02651387f, 0.06023021f, 2.46324205f, -0.19236097f, 0.30154735f, -0.03471177f, -1.16867805f, 1.14282286f, 0.75193304f, 0.79103196f, -0.90938747f, 1.40279436f, -1.40185106f, 0.58685708f, 2.19045568f, -0.99053633f, -0.56629771f, 0.09965137f, -0.50347567f, -1.55066347f, 0.06856298f, -1.06230366f, 0.47359243f, -0.91942424f, 1.54993439f, -0.78325331f, -0.32206151f, 0.81351721f, -1.23086429f, 0.22745994f, 1.30714273f, -1.60748327f, 0.18463387f, 0.25988281f, 0.78182286f, -1.23695076f, -1.32045662f, 0.52194154f, 0.29698467f, 0.25049284f, 0.34644821f, -0.68002474f, 0.23225370f, 0.29307246f, -0.71435142f, 1.86577451f, 0.47383294f, -1.19130349f, 0.65655363f, -0.97468168f, 0.78708458f, 1.15859556f, -0.82068235f, 0.96337610f, 0.41278094f, 0.82206017f, 1.89679301f, -0.24538812f, -0.75373614f, -0.88951445f, -0.81581026f, -0.07710171f, 0.34115198f, 0.27669081f, 0.82718325f, 0.01300189f, 1.45353413f, -0.26465684f, 2.72016907f, 0.62566733f, -0.85715753f, -1.07089245f, 0.48247242f, -0.22346279f, 0.71400052f, 0.47323763f, -0.07282891f, -0.84679371f, -1.51484728f, -0.44651496f, 0.85639882f, 0.21409374f, -1.24573874f, 0.17318092f, 0.38531739f, -0.88385743f, 0.15372510f, 0.05820872f, -1.14297032f, 0.35778737f, 0.56078452f, 1.08305120f, 1.05380201f, -1.37766933f, -0.93782502f, 0.51503527f, 0.51378596f, 0.51504767f, 3.85273147f, 0.57089049f, 1.13556564f, 0.95400178f, 0.65139127f, -0.31526923f, 0.75896925f, -0.77282524f, -0.23681861f, -0.48536354f, 0.08187414f, 2.31465864f, -1.86726522f, 0.68626016f, -1.61271584f, -0.47193187f, 1.08895063f, 0.06428002f, -1.07774472f, -0.71530372f, 0.67959774f, -0.73036665f, 0.21645859f, 0.04557184f, -0.65160036f, 2.14394403f, 0.63391900f, -2.02514267f, 0.18645431f, -0.66178644f, 0.85243332f, -0.79252076f, -0.11473644f, 0.50498730f, 0.86575520f, -1.20029640f, -0.33450124f, -0.47494531f, -0.65332925f, 1.76545429f, 0.40498170f, -1.26088393f, 0.91786194f, 2.12215614f, 1.03246522f, -1.51936996f, -0.48423406f, 1.26691115f, -0.70766944f, 0.44381943f, 0.77463406f, -0.92693049f, -0.05952536f, -3.24126744f, -1.02438760f, -0.25256816f, -1.24778318f, 1.63241136f, -1.43014133f, -0.44004449f, 0.13074058f, 1.44127333f, -1.43586218f, 1.16316378f, 0.01023306f, -0.98150867f, 0.46210349f, 0.19905970f, -0.60021687f, 0.06980208f, -0.38531360f, 0.11351734f, 0.66213065f, 1.58601677f, -1.23781550f, 2.13303328f, -1.95208776f, -0.15178509f, 0.58831722f, 0.28099188f, -0.62269950f, -0.20812225f, -0.49300092f, -0.58936477f, 0.84960210f, 0.35701549f, -0.69290960f, 0.89959985f, 0.30729952f, 0.81286210f, 0.62962884f, -0.82899499f, -0.56018102f, 0.74729359f, 0.61037028f, -0.02090159f, 0.11732738f, 1.27766490f, -0.59157139f, 0.54709738f, -0.20219265f, -0.21768120f, 1.09877682f, 0.82541633f, 0.81350964f, 1.30547881f, 0.02100384f, 0.68195295f, -0.31026676f, 0.32416636f, -0.13014306f, 0.09699596f, 0.59515703f, -0.81822067f, 2.09238720f, -1.00601733f, -1.21418858f, 1.15811086f, 0.79166269f, 0.62411982f, 0.62834549f, -0.01224677f, -0.89725435f, 0.07580456f, -0.67716169f, 0.97511971f, -0.14705738f, -0.82549721f, -0.32138583f, 0.41293144f, -0.56372458f, -0.82222039f, 0.24368721f, 0.24496657f, -0.50694317f, -0.47103831f, 0.23204994f, -1.44808435f, -1.40746379f, -0.71844423f, -0.21344715f, 0.31090757f, 1.47535622f, 0.85765964f, -0.15993853f, -0.01901621f, -1.00252938f, -0.01851314f, -0.28865865f, 0.32271856f, -0.82723093f, 0.51934654f, 1.53273892f, -0.10876015f, 0.40171173f, 0.69014400f, -0.40122047f, 0.22409248f, 0.01259240f, 0.09767610f, -0.77300978f, 0.02451017f, 0.49799830f, 1.45114362f, 0.95927083f, 2.15318251f, -0.76734757f, 0.87232065f, 0.18334201f, 2.18980289f, -0.80829829f, -0.83972186f, -0.59939265f, -2.12389565f, -0.52575505f, -0.75913268f, 0.15039378f, 0.34175599f, 1.87617087f, 0.95042384f, -0.57690364f, -0.89841467f, 0.49191916f, -1.32023323f, 1.83145881f, 1.17944014f, -0.46917567f, -1.71313453f, 1.35387242f, -0.11453985f, 1.23781633f};
} // namespace ConvGroupBatch_Input