diff --git a/core/inc/SOFIE/ROperator_Conv.hxx b/core/inc/SOFIE/ROperator_Conv.hxx index 835a0ff..6711bf0 100644 --- a/core/inc/SOFIE/ROperator_Conv.hxx +++ b/core/inc/SOFIE/ROperator_Conv.hxx @@ -836,11 +836,11 @@ public: size_t gemm_n = outChannels; // output channels size_t gemm_k = fShapeW[1] * kernelSize; // input channels/group * kernel volume size_t gemm_m = oDepth * oHeight * oWidth; // output spatial size per channel + if (fAttrGroup > 1) gemm_n /= fAttrGroup; // per-group output channels for grouped conv size_t colElements = gemm_k * gemm_m; // colRows * colCols size_t wTotal = ConvertShapeToLength(fShapeW); // For group conv: per-group output channels and _f offset - // gemm_n stays as total output channels — we divide per group at launch size_t groupFOffset = gemm_n * gemm_k; // elements of _f per group std::stringstream out; @@ -986,6 +986,7 @@ public: size_t gemm_n_ = fShapeW[0]; size_t gemm_k_ = fShapeW[1] * kSize_; size_t gemm_m_ = oDepth_ * oHeight_ * oWidth_; + if (fAttrGroup > 1) gemm_n_ /= fAttrGroup; auto lda = std::to_string(gemm_m_); // ld for xcol^T (gemm_m×gemm_k col-major) auto ldb = std::to_string(gemm_k_); // ld for xf^T (gemm_k×gemm_n col-major) auto ldc = std::to_string(gemm_m_); // ld for y^T (gemm_m×gemm_n col-major) diff --git a/core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx b/core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx index e415cce..7963095 100644 --- a/core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx +++ b/core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx @@ -115,6 +115,10 @@ #include "ConvWithAsymmetricPadding_FromONNX_GPU_ALPAKA.hxx" #include "input_models/references/ConvWithAsymmetricPadding.ref.hxx" +#include "ConvGroupBatch_FromONNX_GPU_ALPAKA.hxx" +#include "input_models/references/ConvGroupBatch.ref.hxx" +#include "input_models/references/ConvGroupBatch_input.ref.hxx" + #include "BatchNorm_FromONNX_GPU_ALPAKA.hxx" #include "BatchNormRelu_FromONNX_GPU_ALPAKA.hxx" @@ -2200,6 +2204,39 @@ TEST_F(SofieAlpakaTest, ConvWithAsymmetricPadding) } } +// group=2 and batch=4 together; patch test for the gemm_n /= fAttrGroup fix +TEST_F(SofieAlpakaTest, ConvGroupBatch) +{ + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + + constexpr size_t N = sizeof(ConvGroupBatch_Input::data) / sizeof(float); + auto input_h = alpaka::allocBuf(host, Ext1D::all(Idx{N})); + float* input_ptr = reinterpret_cast(alpaka::getPtrNative(input_h)); + for (Idx i = 0; i < N; ++i) input_ptr[i] = ConvGroupBatch_Input::data[i]; + + auto input_d = alpaka::allocBuf(device, Ext1D::all(Idx{N})); + alpaka::memcpy(queue, input_d, input_h); + alpaka::wait(queue); + + constexpr size_t nOut = sizeof(ConvGroupBatch_ExpectedOutput::output) / sizeof(float); + auto result_h = alpaka::allocBuf(host, Ext1D::all(Idx{nOut})); + + { + SOFIE_ConvGroupBatch::Session session("ConvGroupBatch_FromONNX_GPU_ALPAKA.dat"); + auto result = session.infer(input_d); + alpaka::wait(queue); + cudaDeviceSynchronize(); + alpaka::memcpy(queue, result_h, result); + alpaka::wait(queue); + } + + float* res_ptr = reinterpret_cast(alpaka::getPtrNative(result_h)); + float* correct = ConvGroupBatch_ExpectedOutput::output; + for (size_t i = 0; i < nOut; ++i) { + EXPECT_LE(std::abs(res_ptr[i] - correct[i]), TOLERANCE) << "i=" << i; + } +} + TEST_F(SofieAlpakaTest, BatchNormalization) { constexpr float TOLERANCE = DEFAULT_TOLERANCE; diff --git a/core/test/input_models/ConvGroupBatch.onnx b/core/test/input_models/ConvGroupBatch.onnx new file mode 100644 index 0000000..0f58d66 Binary files /dev/null and b/core/test/input_models/ConvGroupBatch.onnx differ diff --git a/core/test/input_models/references/ConvGroupBatch.ref.hxx b/core/test/input_models/references/ConvGroupBatch.ref.hxx new file mode 100644 index 0000000..abe7176 --- /dev/null +++ b/core/test/input_models/references/ConvGroupBatch.ref.hxx @@ -0,0 +1,5 @@ +// Auto-generated by ConvGroupBatchModelGenerator.py - DO NOT EDIT +#pragma once +namespace ConvGroupBatch_ExpectedOutput { + float output[] = {-0.36076069f, 0.33097583f, -0.28844666f, -1.04436600f, 0.49192899f, -0.12130111f, 0.39897686f, -1.49153733f, -0.76932675f, 0.48658341f, -0.71374869f, 0.40121353f, -1.12167597f, -0.62728310f, 0.35699576f, -0.47379202f, 0.01621294f, 0.17111908f, -0.74641943f, 0.11123379f, -0.39323545f, 0.49183100f, 0.45035821f, -0.15920340f, -0.45440555f, -0.42033449f, 1.59069741f, 1.25395763f, -0.02720551f, 1.12741852f, 0.57816750f, -0.84116793f, 0.71935523f, 0.19200583f, -0.77897775f, 0.69070578f, 0.69550228f, 1.33932674f, 0.79535830f, -1.28075194f, -0.65100706f, 0.25080451f, -0.68134499f, -1.09874177f, -0.86597049f, 0.00925574f, 1.13835633f, -1.06737292f, -0.24952698f, -0.28619090f, 0.29835856f, -0.21624319f, 0.15673074f, 0.41889524f, 0.55454159f, 0.61738980f, 0.88214344f, 0.09535512f, 0.24054861f, 0.51925898f, 0.24246022f, -0.68892086f, 0.09302564f, -0.42311478f, -0.22747509f, 0.13375665f, -0.33798844f, 0.15409425f, 0.62915587f, -0.97260368f, -0.03625952f, 0.69901109f, -0.23517157f, 0.67972404f, 0.21827847f, 0.11843839f, -0.01277096f, -0.32323205f, -0.03658739f, -0.31435120f, -0.43483058f, -0.27148122f, 0.77499795f, -0.81867105f, -0.99898142f, -0.67271334f, 1.01101458f, 0.78309727f, 0.88418299f, 0.14850061f, -0.46363044f, -0.06768467f, 0.44989449f, 0.52169687f, 0.44447100f, -0.94505298f, -0.56453151f, -0.72638625f, 0.33930093f, -0.03495500f, -0.02821438f, -1.03944421f, -0.78911626f, 0.30141425f, -0.71277291f, 0.87587613f, 0.76956397f, -0.30962253f, 0.66245174f, -1.33300364f, 0.17070784f, -0.87962157f, -1.13295639f, 0.44555551f, -1.04297531f, 0.75829023f, -0.03878885f, -1.10133481f, 0.55174708f, -0.75599957f, -0.51625776f, -0.05342628f, -0.36763048f, 0.65811008f, -0.70035321f, -1.43289232f, -0.17264232f, 0.24939704f, 0.22910510f, -0.15060824f, 1.66544592f, 0.13985462f, -0.15302825f, 1.13994932f, 0.22160804f, 0.24457796f, 0.47076693f, 0.20962349f, -0.05407842f, 1.70852649f, 1.12710464f, -1.44766259f, 1.25966561f, 0.29169840f, 0.92405045f, -0.75349152f, 1.58519661f, 0.49950320f, 0.67342019f, -0.26927525f, -0.67771661f, 0.40583348f, -0.57214516f, -0.60764003f, 0.25248834f, -0.43436378f, 0.62197685f, 0.76375240f, -1.36287653f, 0.84372342f, -0.35162437f, -1.54409766f, 0.51240182f, 0.20225400f, -0.02268404f, 0.30259740f, -0.04716773f, 0.67471498f, 0.71413249f, -0.46658528f, -0.55675507f, -0.51797366f, 0.41532481f, 1.02307224f, 0.39207584f, 0.08235835f, 0.13141029f, 1.01161790f, 0.15537255f, -0.23259896f, -0.64753038f, -0.84689122f, 0.32039008f, 0.86086088f, -0.10391510f, -0.43225932f, -0.34928828f, 0.56980270f, 0.50981444f, -0.21508121f, 0.16830127f, -1.03776038f, -0.56918675f, 0.11749696f, 0.10682096f, 0.31032404f, -0.09735793f, -0.31645539f, -0.08064906f, -0.20384440f, -0.01380761f, -0.07755496f, -0.60068375f, -0.40580261f, 0.91849798f, 0.08390346f, -0.52695745f, -0.42855084f, -1.03313851f, 0.33994401f, -0.36245525f, -0.69688886f, 0.78342688f, -0.97037071f, -0.48480308f, -0.27335811f, -0.97596729f, -0.23528086f, -0.21981074f, -0.63335001f, -0.88829535f, -0.29944593f, -0.19596440f, 1.19875622f, -0.49918032f, -0.79209542f, 0.10065345f, 1.09010482f, 0.01321086f, 2.13910604f, 1.11112499f, 0.72985148f, 1.00838768f, 0.54585910f, -0.15993178f, 0.26541781f, 0.85031760f, 1.73971391f, 1.52358532f, 0.45130971f, 2.48678756f, 0.19343181f, 0.12072347f, -0.44385901f, 0.52359331f, -0.61506343f, 1.24399531f, -1.18155575f, 0.42505923f, -1.23762119f, -0.31627494f, 0.56629950f, -0.52162617f, 1.12653041f, -0.74352539f, 0.52006209f, -0.15141015f, -0.73136127f, 2.61348605f, 1.23570156f, -0.31728053f, 0.77408886f, 0.27774701f, -0.56392157f, 0.45240438f, 0.84599215f, -0.27067959f, 2.16201377f, -0.71181077f, -0.21107928f, 0.28580725f, -0.41565710f, 0.90941215f, -0.52555859f, -0.16738491f, -0.54808658f, -1.08526981f, 0.06592142f, -1.71680820f, 0.08483648f, 0.01231547f, 0.08641901f, -0.34641835f, -1.74067116f, -0.29078999f, -0.16861805f, 0.45419741f, 1.03918445f, 0.26395023f, -0.40493611f, -0.22315894f, 0.71085817f, -0.88118380f, 1.08953798f, 0.25373179f, 0.14846370f, -0.49211681f, 0.22963192f, 0.22093298f, 0.33117050f, -0.37713605f, 0.35486460f, -0.36020797f, -0.33668989f, 0.11798792f, -0.43228078f, -0.57194173f, 0.20756347f, -0.65037400f, -0.15730046f, -0.13377419f, -0.50105321f, -0.28189951f, -0.22371638f, -0.74900097f, -0.02032955f, -0.61617506f, -1.08462250f, -0.56890285f, -0.93203634f, -0.29669374f, -1.03467071f, 0.21930207f, -0.11985322f, -0.08238113f, 0.27048582f, -0.00008567f, -0.82777119f, -0.17688563f, 0.15913069f, -0.80205917f, 0.10852872f, 1.17724657f, 0.41065979f, 0.81592619f, 0.93737853f, 0.34568185f, 1.63516307f, -0.16411117f, 0.52121985f, 1.25948155f, 0.26429343f, -0.58898950f, 1.38607669f, 0.54334569f, 0.25247136f, 0.33251810f, 2.30225015f, 0.15690750f, 1.08385587f, 0.45817482f, 0.39435270f, 0.23890808f, -0.37897390f, 0.04152094f, -0.93999600f, -0.31861931f, -0.26991367f, 1.17474163f, 0.61628640f, 0.66774678f, -0.05754669f, -0.06096449f, -0.57611501f, 0.53568387f, -0.61293066f, -0.34657979f, -0.32216960f, -0.81506526f, -1.00963306f, 0.16005160f, 0.55423015f, -0.18409301f, 0.05382362f, 0.64573503f, -0.30393291f, 0.31414580f, -0.11099648f, 0.88745099f, 0.38012391f, -0.45172656f, -0.73033381f, -0.79347444f, -0.27273351f, 0.03713746f, -0.46960682f, 0.30829662f, -0.86766714f, -0.20282577f, -0.30651718f, 0.38887522f, -0.91033989f, -0.40621519f, -0.28930432f, 0.68293715f, 0.02102731f, -0.87800777f, -0.56976014f, -1.07986236f, -0.12437577f}; +} // namespace ConvGroupBatch_ExpectedOutput diff --git a/core/test/input_models/references/ConvGroupBatch_input.ref.hxx b/core/test/input_models/references/ConvGroupBatch_input.ref.hxx new file mode 100644 index 0000000..acdfcb0 --- /dev/null +++ b/core/test/input_models/references/ConvGroupBatch_input.ref.hxx @@ -0,0 +1,5 @@ +// Auto-generated by ConvGroupBatchModelGenerator.py - DO NOT EDIT +#pragma once +namespace ConvGroupBatch_Input { + static float data[400] = {0.49671414f, -0.13826430f, 0.64768857f, 1.52302980f, -0.23415338f, -0.23413695f, 1.57921278f, 0.76743472f, -0.46947438f, 0.54256004f, -0.46341768f, -0.46572974f, 0.24196227f, -1.91328025f, -1.72491789f, -0.56228751f, -1.01283109f, 0.31424734f, -0.90802407f, -1.41230369f, 1.46564877f, -0.22577630f, 0.06752820f, -1.42474818f, -0.54438275f, 0.11092259f, -1.15099359f, 0.37569803f, -0.60063869f, -0.29169375f, -0.60170662f, 1.85227823f, -0.01349723f, -1.05771089f, 0.82254493f, -1.22084367f, 0.20886360f, -1.95967007f, -1.32818604f, 0.19686124f, 0.73846656f, 0.17136829f, -0.11564828f, -0.30110368f, -1.47852194f, -0.71984422f, -0.46063876f, 1.05712223f, 0.34361830f, -1.76304018f, 0.32408398f, -0.38508227f, -0.67692202f, 0.61167628f, 1.03099954f, 0.93128014f, -0.83921754f, -0.30921239f, 0.33126342f, 0.97554511f, -0.47917423f, -0.18565898f, -1.10633492f, -1.19620657f, 0.81252581f, 1.35624003f, -0.07201012f, 1.00353289f, 0.36163601f, -0.64511973f, 0.36139560f, 1.53803658f, -0.03582604f, 1.56464362f, -2.61974502f, 0.82190251f, 0.08704707f, -0.29900736f, 0.09176078f, -1.98756886f, -0.21967189f, 0.35711256f, 1.47789407f, -0.51827019f, -0.80849361f, -0.50175703f, 0.91540211f, 0.32875112f, -0.52976018f, 0.51326746f, 0.09707755f, 0.96864498f, -0.70205307f, -0.32766214f, -0.39210814f, -1.46351492f, 0.29612029f, 0.26105526f, 0.00511346f, -0.23458713f, -1.41537070f, -0.42064533f, -0.34271452f, -0.80227727f, -0.16128571f, 0.40405086f, 1.88618588f, 0.17457782f, 0.25755039f, -0.07444592f, -1.91877127f, -0.02651387f, 0.06023021f, 2.46324205f, -0.19236097f, 0.30154735f, -0.03471177f, -1.16867805f, 1.14282286f, 0.75193304f, 0.79103196f, -0.90938747f, 1.40279436f, -1.40185106f, 0.58685708f, 2.19045568f, -0.99053633f, -0.56629771f, 0.09965137f, -0.50347567f, -1.55066347f, 0.06856298f, -1.06230366f, 0.47359243f, -0.91942424f, 1.54993439f, -0.78325331f, -0.32206151f, 0.81351721f, -1.23086429f, 0.22745994f, 1.30714273f, -1.60748327f, 0.18463387f, 0.25988281f, 0.78182286f, -1.23695076f, -1.32045662f, 0.52194154f, 0.29698467f, 0.25049284f, 0.34644821f, -0.68002474f, 0.23225370f, 0.29307246f, -0.71435142f, 1.86577451f, 0.47383294f, -1.19130349f, 0.65655363f, -0.97468168f, 0.78708458f, 1.15859556f, -0.82068235f, 0.96337610f, 0.41278094f, 0.82206017f, 1.89679301f, -0.24538812f, -0.75373614f, -0.88951445f, -0.81581026f, -0.07710171f, 0.34115198f, 0.27669081f, 0.82718325f, 0.01300189f, 1.45353413f, -0.26465684f, 2.72016907f, 0.62566733f, -0.85715753f, -1.07089245f, 0.48247242f, -0.22346279f, 0.71400052f, 0.47323763f, -0.07282891f, -0.84679371f, -1.51484728f, -0.44651496f, 0.85639882f, 0.21409374f, -1.24573874f, 0.17318092f, 0.38531739f, -0.88385743f, 0.15372510f, 0.05820872f, -1.14297032f, 0.35778737f, 0.56078452f, 1.08305120f, 1.05380201f, -1.37766933f, -0.93782502f, 0.51503527f, 0.51378596f, 0.51504767f, 3.85273147f, 0.57089049f, 1.13556564f, 0.95400178f, 0.65139127f, -0.31526923f, 0.75896925f, -0.77282524f, -0.23681861f, -0.48536354f, 0.08187414f, 2.31465864f, -1.86726522f, 0.68626016f, -1.61271584f, -0.47193187f, 1.08895063f, 0.06428002f, -1.07774472f, -0.71530372f, 0.67959774f, -0.73036665f, 0.21645859f, 0.04557184f, -0.65160036f, 2.14394403f, 0.63391900f, -2.02514267f, 0.18645431f, -0.66178644f, 0.85243332f, -0.79252076f, -0.11473644f, 0.50498730f, 0.86575520f, -1.20029640f, -0.33450124f, -0.47494531f, -0.65332925f, 1.76545429f, 0.40498170f, -1.26088393f, 0.91786194f, 2.12215614f, 1.03246522f, -1.51936996f, -0.48423406f, 1.26691115f, -0.70766944f, 0.44381943f, 0.77463406f, -0.92693049f, -0.05952536f, -3.24126744f, -1.02438760f, -0.25256816f, -1.24778318f, 1.63241136f, -1.43014133f, -0.44004449f, 0.13074058f, 1.44127333f, -1.43586218f, 1.16316378f, 0.01023306f, -0.98150867f, 0.46210349f, 0.19905970f, -0.60021687f, 0.06980208f, -0.38531360f, 0.11351734f, 0.66213065f, 1.58601677f, -1.23781550f, 2.13303328f, -1.95208776f, -0.15178509f, 0.58831722f, 0.28099188f, -0.62269950f, -0.20812225f, -0.49300092f, -0.58936477f, 0.84960210f, 0.35701549f, -0.69290960f, 0.89959985f, 0.30729952f, 0.81286210f, 0.62962884f, -0.82899499f, -0.56018102f, 0.74729359f, 0.61037028f, -0.02090159f, 0.11732738f, 1.27766490f, -0.59157139f, 0.54709738f, -0.20219265f, -0.21768120f, 1.09877682f, 0.82541633f, 0.81350964f, 1.30547881f, 0.02100384f, 0.68195295f, -0.31026676f, 0.32416636f, -0.13014306f, 0.09699596f, 0.59515703f, -0.81822067f, 2.09238720f, -1.00601733f, -1.21418858f, 1.15811086f, 0.79166269f, 0.62411982f, 0.62834549f, -0.01224677f, -0.89725435f, 0.07580456f, -0.67716169f, 0.97511971f, -0.14705738f, -0.82549721f, -0.32138583f, 0.41293144f, -0.56372458f, -0.82222039f, 0.24368721f, 0.24496657f, -0.50694317f, -0.47103831f, 0.23204994f, -1.44808435f, -1.40746379f, -0.71844423f, -0.21344715f, 0.31090757f, 1.47535622f, 0.85765964f, -0.15993853f, -0.01901621f, -1.00252938f, -0.01851314f, -0.28865865f, 0.32271856f, -0.82723093f, 0.51934654f, 1.53273892f, -0.10876015f, 0.40171173f, 0.69014400f, -0.40122047f, 0.22409248f, 0.01259240f, 0.09767610f, -0.77300978f, 0.02451017f, 0.49799830f, 1.45114362f, 0.95927083f, 2.15318251f, -0.76734757f, 0.87232065f, 0.18334201f, 2.18980289f, -0.80829829f, -0.83972186f, -0.59939265f, -2.12389565f, -0.52575505f, -0.75913268f, 0.15039378f, 0.34175599f, 1.87617087f, 0.95042384f, -0.57690364f, -0.89841467f, 0.49191916f, -1.32023323f, 1.83145881f, 1.17944014f, -0.46917567f, -1.71313453f, 1.35387242f, -0.11453985f, 1.23781633f}; +} // namespace ConvGroupBatch_Input