diff --git a/core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx b/core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx index e415cce..c6b8027 100644 --- a/core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx +++ b/core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx @@ -115,6 +115,15 @@ #include "ConvWithAsymmetricPadding_FromONNX_GPU_ALPAKA.hxx" #include "input_models/references/ConvWithAsymmetricPadding.ref.hxx" +#include "ConvGroup2_FromONNX_GPU_ALPAKA.hxx" +#include "input_models/references/ConvGroup2.ref.hxx" + +#include "ConvGroup4_FromONNX_GPU_ALPAKA.hxx" +#include "input_models/references/ConvGroup4.ref.hxx" + +#include "ConvBatch4Group2_FromONNX_GPU_ALPAKA.hxx" +#include "input_models/references/ConvBatch4Group2.ref.hxx" + #include "BatchNorm_FromONNX_GPU_ALPAKA.hxx" #include "BatchNormRelu_FromONNX_GPU_ALPAKA.hxx" @@ -2200,6 +2209,108 @@ TEST_F(SofieAlpakaTest, ConvWithAsymmetricPadding) } } +TEST_F(SofieAlpakaTest, ConvGroup2) +{ + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + + std::vector input(100); + std::iota(input.begin(), input.end(), 0.0f); + + auto input_h = alpaka::allocBuf(host, Ext1D::all(Idx{input.size()})); + float* input_ptr = reinterpret_cast(alpaka::getPtrNative(input_h)); + for (Idx i = 0; i < input.size(); ++i) input_ptr[i] = input[i]; + + auto input_d = alpaka::allocBuf(device, Ext1D::all(Idx{input.size()})); + alpaka::memcpy(queue, input_d, input_h); + alpaka::wait(queue); + + auto result_h = alpaka::allocBuf(host, Ext1D::all(Idx{sizeof(ConvGroup2_ExpectedOutput::correct) / sizeof(float)})); + + { + SOFIE_ConvGroup2::Session session("ConvGroup2_FromONNX_GPU_ALPAKA.dat"); + auto result = session.infer(input_d); + alpaka::wait(queue); + cudaDeviceSynchronize(); + alpaka::memcpy(queue, result_h, result); + alpaka::wait(queue); + } + + float* res_ptr = reinterpret_cast(alpaka::getPtrNative(result_h)); + float* correct = ConvGroup2_ExpectedOutput::correct; + constexpr size_t nOut = sizeof(ConvGroup2_ExpectedOutput::correct) / sizeof(float); + + for (size_t i = 0; i < nOut; ++i) + EXPECT_LE(std::abs(res_ptr[i] - correct[i]), TOLERANCE) << "i=" << i; +} + +TEST_F(SofieAlpakaTest, ConvGroup4) +{ + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + + std::vector input(100); + std::iota(input.begin(), input.end(), 0.0f); + + auto input_h = alpaka::allocBuf(host, Ext1D::all(Idx{input.size()})); + float* input_ptr = reinterpret_cast(alpaka::getPtrNative(input_h)); + for (Idx i = 0; i < input.size(); ++i) input_ptr[i] = input[i]; + + auto input_d = alpaka::allocBuf(device, Ext1D::all(Idx{input.size()})); + alpaka::memcpy(queue, input_d, input_h); + alpaka::wait(queue); + + auto result_h = alpaka::allocBuf(host, Ext1D::all(Idx{sizeof(ConvGroup4_ExpectedOutput::correct) / sizeof(float)})); + + { + SOFIE_ConvGroup4::Session session("ConvGroup4_FromONNX_GPU_ALPAKA.dat"); + auto result = session.infer(input_d); + alpaka::wait(queue); + cudaDeviceSynchronize(); + alpaka::memcpy(queue, result_h, result); + alpaka::wait(queue); + } + + float* res_ptr = reinterpret_cast(alpaka::getPtrNative(result_h)); + float* correct = ConvGroup4_ExpectedOutput::correct; + constexpr size_t nOut = sizeof(ConvGroup4_ExpectedOutput::correct) / sizeof(float); + + for (size_t i = 0; i < nOut; ++i) + EXPECT_LE(std::abs(res_ptr[i] - correct[i]), TOLERANCE) << "i=" << i; +} + +TEST_F(SofieAlpakaTest, ConvBatch4Group2) +{ + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + + std::vector input(400); + std::iota(input.begin(), input.end(), 0.0f); + + auto input_h = alpaka::allocBuf(host, Ext1D::all(Idx{input.size()})); + float* input_ptr = reinterpret_cast(alpaka::getPtrNative(input_h)); + for (Idx i = 0; i < input.size(); ++i) input_ptr[i] = input[i]; + + auto input_d = alpaka::allocBuf(device, Ext1D::all(Idx{input.size()})); + alpaka::memcpy(queue, input_d, input_h); + alpaka::wait(queue); + + auto result_h = alpaka::allocBuf(host, Ext1D::all(Idx{sizeof(ConvBatch4Group2_ExpectedOutput::correct) / sizeof(float)})); + + { + SOFIE_ConvBatch4Group2::Session session("ConvBatch4Group2_FromONNX_GPU_ALPAKA.dat"); + auto result = session.infer(input_d); + alpaka::wait(queue); + cudaDeviceSynchronize(); + alpaka::memcpy(queue, result_h, result); + alpaka::wait(queue); + } + + float* res_ptr = reinterpret_cast(alpaka::getPtrNative(result_h)); + float* correct = ConvBatch4Group2_ExpectedOutput::correct; + constexpr size_t nOut = sizeof(ConvBatch4Group2_ExpectedOutput::correct) / sizeof(float); + + for (size_t i = 0; i < nOut; ++i) + EXPECT_LE(std::abs(res_ptr[i] - correct[i]), TOLERANCE) << "i=" << i; +} + TEST_F(SofieAlpakaTest, BatchNormalization) { constexpr float TOLERANCE = DEFAULT_TOLERANCE; diff --git a/core/test/input_models/ConvBatch4Group2.onnx b/core/test/input_models/ConvBatch4Group2.onnx new file mode 100644 index 0000000..d39b881 Binary files /dev/null and b/core/test/input_models/ConvBatch4Group2.onnx differ diff --git a/core/test/input_models/ConvGroup2.onnx b/core/test/input_models/ConvGroup2.onnx new file mode 100644 index 0000000..cda456d Binary files /dev/null and b/core/test/input_models/ConvGroup2.onnx differ diff --git a/core/test/input_models/ConvGroup4.onnx b/core/test/input_models/ConvGroup4.onnx new file mode 100644 index 0000000..e028850 Binary files /dev/null and b/core/test/input_models/ConvGroup4.onnx differ diff --git a/core/test/input_models/references/ConvBatch4Group2.ref.hxx b/core/test/input_models/references/ConvBatch4Group2.ref.hxx new file mode 100644 index 0000000..28d0021 --- /dev/null +++ b/core/test/input_models/references/ConvBatch4Group2.ref.hxx @@ -0,0 +1,3 @@ +namespace ConvBatch4Group2_ExpectedOutput { +float correct[] = {1908.000000f, 2840.000000f, 2972.000000f, 3104.000000f, 2036.000000f, 2958.000000f, 4368.000000f, 4539.000000f, 4710.000000f, 3066.000000f, 3558.000000f, 5223.000000f, 5394.000000f, 5565.000000f, 3606.000000f, 4158.000000f, 6078.000000f, 6249.000000f, 6420.000000f, 4146.000000f, 2556.000000f, 3704.000000f, 3800.000000f, 3896.000000f, 2492.000000f, 4140.000000f, 6296.000000f, 6644.000000f, 6992.000000f, 4700.000000f, 6846.000000f, 10362.000000f, 10857.000000f, 11352.000000f, 7602.000000f, 8526.000000f, 12837.000000f, 13332.000000f, 13827.000000f, 9222.000000f, 10206.000000f, 15312.000000f, 15807.000000f, 16302.000000f, 10842.000000f, 6948.000000f, 10400.000000f, 10712.000000f, 11024.000000f, 7316.000000f, 25372.000000f, 37952.000000f, 38516.000000f, 39080.000000f, 25964.000000f, 38334.000000f, 57306.000000f, 58125.000000f, 58944.000000f, 39138.000000f, 41094.000000f, 61401.000000f, 62220.000000f, 63039.000000f, 41838.000000f, 43854.000000f, 65496.000000f, 66315.000000f, 67134.000000f, 44538.000000f, 29140.000000f, 43496.000000f, 44024.000000f, 44552.000000f, 29540.000000f, 34804.000000f, 52208.000000f, 52988.000000f, 53768.000000f, 35828.000000f, 53022.000000f, 79500.000000f, 80643.000000f, 81786.000000f, 54474.000000f, 56862.000000f, 85215.000000f, 86358.000000f, 87501.000000f, 58254.000000f, 60702.000000f, 90930.000000f, 92073.000000f, 93216.000000f, 62034.000000f, 40732.000000f, 60992.000000f, 61736.000000f, 62480.000000f, 41564.000000f, 11108.000000f, 16040.000000f, 16172.000000f, 16304.000000f, 10436.000000f, 14958.000000f, 21468.000000f, 21639.000000f, 21810.000000f, 13866.000000f, 15558.000000f, 22323.000000f, 22494.000000f, 22665.000000f, 14406.000000f, 16158.000000f, 23178.000000f, 23349.000000f, 23520.000000f, 14946.000000f, 9356.000000f, 13304.000000f, 13400.000000f, 13496.000000f, 8492.000000f, 27740.000000f, 41096.000000f, 41444.000000f, 41792.000000f, 27500.000000f, 40446.000000f, 59862.000000f, 60357.000000f, 60852.000000f, 40002.000000f, 42126.000000f, 62337.000000f, 62832.000000f, 63327.000000f, 41622.000000f, 43806.000000f, 64812.000000f, 65307.000000f, 65802.000000f, 43242.000000f, 28148.000000f, 41600.000000f, 41912.000000f, 42224.000000f, 27716.000000f, 63372.000000f, 94352.000000f, 94916.000000f, 95480.000000f, 63164.000000f, 93534.000000f, 139206.000000f, 140025.000000f, 140844.000000f, 93138.000000f, 96294.000000f, 143301.000000f, 144120.000000f, 144939.000000f, 95838.000000f, 99054.000000f, 147396.000000f, 148215.000000f, 149034.000000f, 98538.000000f, 64740.000000f, 96296.000000f, 96824.000000f, 97352.000000f, 64340.000000f, 87204.000000f, 130208.000000f, 130988.000000f, 131768.000000f, 87428.000000f, 129822.000000f, 193800.000000f, 194943.000000f, 196086.000000f, 130074.000000f, 133662.000000f, 199515.000000f, 200658.000000f, 201801.000000f, 133854.000000f, 137502.000000f, 205230.000000f, 206373.000000f, 207516.000000f, 137634.000000f, 90732.000000f, 135392.000000f, 136136.000000f, 136880.000000f, 90764.000000f, 20308.000000f, 29240.000000f, 29372.000000f, 29504.000000f, 18836.000000f, 26958.000000f, 38568.000000f, 38739.000000f, 38910.000000f, 24666.000000f, 27558.000000f, 39423.000000f, 39594.000000f, 39765.000000f, 25206.000000f, 28158.000000f, 40278.000000f, 40449.000000f, 40620.000000f, 25746.000000f, 16156.000000f, 22904.000000f, 23000.000000f, 23096.000000f, 14492.000000f, 51340.000000f, 75896.000000f, 76244.000000f, 76592.000000f, 50300.000000f, 74046.000000f, 109362.000000f, 109857.000000f, 110352.000000f, 72402.000000f, 75726.000000f, 111837.000000f, 112332.000000f, 112827.000000f, 74022.000000f, 77406.000000f, 114312.000000f, 114807.000000f, 115302.000000f, 75642.000000f, 49348.000000f, 72800.000000f, 73112.000000f, 73424.000000f, 48116.000000f, 101372.000000f, 150752.000000f, 151316.000000f, 151880.000000f, 100364.000000f, 148734.000000f, 221106.000000f, 221925.000000f, 222744.000000f, 147138.000000f, 151494.000000f, 225201.000000f, 226020.000000f, 226839.000000f, 149838.000000f, 154254.000000f, 229296.000000f, 230115.000000f, 230934.000000f, 152538.000000f, 100340.000000f, 149096.000000f, 149624.000000f, 150152.000000f, 99140.000000f, 139604.000000f, 208208.000000f, 208988.000000f, 209768.000000f, 139028.000000f, 206622.000000f, 308100.000000f, 309243.000000f, 310386.000000f, 205674.000000f, 210462.000000f, 313815.000000f, 314958.000000f, 316101.000000f, 209454.000000f, 214302.000000f, 319530.000000f, 320673.000000f, 321816.000000f, 213234.000000f, 140732.000000f, 209792.000000f, 210536.000000f, 211280.000000f, 139964.000000f, 29508.000000f, 42440.000000f, 42572.000000f, 42704.000000f, 27236.000000f, 38958.000000f, 55668.000000f, 55839.000000f, 56010.000000f, 35466.000000f, 39558.000000f, 56523.000000f, 56694.000000f, 56865.000000f, 36006.000000f, 40158.000000f, 57378.000000f, 57549.000000f, 57720.000000f, 36546.000000f, 22956.000000f, 32504.000000f, 32600.000000f, 32696.000000f, 20492.000000f, 74940.000000f, 110696.000000f, 111044.000000f, 111392.000000f, 73100.000000f, 107646.000000f, 158862.000000f, 159357.000000f, 159852.000000f, 104802.000000f, 109326.000000f, 161337.000000f, 161832.000000f, 162327.000000f, 106422.000000f, 111006.000000f, 163812.000000f, 164307.000000f, 164802.000000f, 108042.000000f, 70548.000000f, 104000.000000f, 104312.000000f, 104624.000000f, 68516.000000f, 139372.000000f, 207152.000000f, 207716.000000f, 208280.000000f, 137564.000000f, 203934.000000f, 303006.000000f, 303825.000000f, 304644.000000f, 201138.000000f, 206694.000000f, 307101.000000f, 307920.000000f, 308739.000000f, 203838.000000f, 209454.000000f, 311196.000000f, 312015.000000f, 312834.000000f, 206538.000000f, 135940.000000f, 201896.000000f, 202424.000000f, 202952.000000f, 133940.000000f, 192004.000000f, 286208.000000f, 286988.000000f, 287768.000000f, 190628.000000f, 283422.000000f, 422400.000000f, 423543.000000f, 424686.000000f, 281274.000000f, 287262.000000f, 428115.000000f, 429258.000000f, 430401.000000f, 285054.000000f, 291102.000000f, 433830.000000f, 434973.000000f, 436116.000000f, 288834.000000f, 190732.000000f, 284192.000000f, 284936.000000f, 285680.000000f, 189164.000000f}; +} // namespace ConvBatch4Group2_ExpectedOutput diff --git a/core/test/input_models/references/ConvGroup2.ref.hxx b/core/test/input_models/references/ConvGroup2.ref.hxx new file mode 100644 index 0000000..3a727fb --- /dev/null +++ b/core/test/input_models/references/ConvGroup2.ref.hxx @@ -0,0 +1,3 @@ +namespace ConvGroup2_ExpectedOutput { +float correct[] = {1908.000000f, 2840.000000f, 2972.000000f, 3104.000000f, 2036.000000f, 2958.000000f, 4368.000000f, 4539.000000f, 4710.000000f, 3066.000000f, 3558.000000f, 5223.000000f, 5394.000000f, 5565.000000f, 3606.000000f, 4158.000000f, 6078.000000f, 6249.000000f, 6420.000000f, 4146.000000f, 2556.000000f, 3704.000000f, 3800.000000f, 3896.000000f, 2492.000000f, 4140.000000f, 6296.000000f, 6644.000000f, 6992.000000f, 4700.000000f, 6846.000000f, 10362.000000f, 10857.000000f, 11352.000000f, 7602.000000f, 8526.000000f, 12837.000000f, 13332.000000f, 13827.000000f, 9222.000000f, 10206.000000f, 15312.000000f, 15807.000000f, 16302.000000f, 10842.000000f, 6948.000000f, 10400.000000f, 10712.000000f, 11024.000000f, 7316.000000f, 25372.000000f, 37952.000000f, 38516.000000f, 39080.000000f, 25964.000000f, 38334.000000f, 57306.000000f, 58125.000000f, 58944.000000f, 39138.000000f, 41094.000000f, 61401.000000f, 62220.000000f, 63039.000000f, 41838.000000f, 43854.000000f, 65496.000000f, 66315.000000f, 67134.000000f, 44538.000000f, 29140.000000f, 43496.000000f, 44024.000000f, 44552.000000f, 29540.000000f, 34804.000000f, 52208.000000f, 52988.000000f, 53768.000000f, 35828.000000f, 53022.000000f, 79500.000000f, 80643.000000f, 81786.000000f, 54474.000000f, 56862.000000f, 85215.000000f, 86358.000000f, 87501.000000f, 58254.000000f, 60702.000000f, 90930.000000f, 92073.000000f, 93216.000000f, 62034.000000f, 40732.000000f, 60992.000000f, 61736.000000f, 62480.000000f, 41564.000000f}; +} // namespace ConvGroup2_ExpectedOutput diff --git a/core/test/input_models/references/ConvGroup4.ref.hxx b/core/test/input_models/references/ConvGroup4.ref.hxx new file mode 100644 index 0000000..5bf7763 --- /dev/null +++ b/core/test/input_models/references/ConvGroup4.ref.hxx @@ -0,0 +1,3 @@ +namespace ConvGroup4_ExpectedOutput { +float correct[] = {100.000000f, 163.000000f, 202.000000f, 241.000000f, 160.000000f, 243.000000f, 366.000000f, 411.000000f, 456.000000f, 291.000000f, 408.000000f, 591.000000f, 636.000000f, 681.000000f, 426.000000f, 573.000000f, 816.000000f, 861.000000f, 906.000000f, 561.000000f, 304.000000f, 415.000000f, 436.000000f, 457.000000f, 268.000000f, 1808.000000f, 2677.000000f, 2770.000000f, 2863.000000f, 1876.000000f, 2715.000000f, 4002.000000f, 4128.000000f, 4254.000000f, 2775.000000f, 3150.000000f, 4632.000000f, 4758.000000f, 4884.000000f, 3180.000000f, 3585.000000f, 5262.000000f, 5388.000000f, 5514.000000f, 3585.000000f, 2252.000000f, 3289.000000f, 3364.000000f, 3439.000000f, 2224.000000f, 5316.000000f, 7891.000000f, 8038.000000f, 8185.000000f, 5392.000000f, 7887.000000f, 11688.000000f, 11895.000000f, 12102.000000f, 7959.000000f, 8592.000000f, 12723.000000f, 12930.000000f, 13137.000000f, 8634.000000f, 9297.000000f, 13758.000000f, 13965.000000f, 14172.000000f, 9309.000000f, 6000.000000f, 8863.000000f, 8992.000000f, 9121.000000f, 5980.000000f, 10624.000000f, 15805.000000f, 16006.000000f, 16207.000000f, 10708.000000f, 15759.000000f, 23424.000000f, 23712.000000f, 24000.000000f, 15843.000000f, 16734.000000f, 24864.000000f, 25152.000000f, 25440.000000f, 16788.000000f, 17709.000000f, 26304.000000f, 26592.000000f, 26880.000000f, 17733.000000f, 11548.000000f, 17137.000000f, 17320.000000f, 17503.000000f, 11536.000000f}; +} // namespace ConvGroup4_ExpectedOutput