Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,15 @@
#include "ConvWithAsymmetricPadding_FromONNX_GPU_ALPAKA.hxx"
#include "input_models/references/ConvWithAsymmetricPadding.ref.hxx"

#include "ConvGroup2_FromONNX_GPU_ALPAKA.hxx"
#include "input_models/references/ConvGroup2.ref.hxx"

#include "ConvGroup4_FromONNX_GPU_ALPAKA.hxx"
#include "input_models/references/ConvGroup4.ref.hxx"

#include "ConvBatch4Group2_FromONNX_GPU_ALPAKA.hxx"
#include "input_models/references/ConvBatch4Group2.ref.hxx"

#include "BatchNorm_FromONNX_GPU_ALPAKA.hxx"
#include "BatchNormRelu_FromONNX_GPU_ALPAKA.hxx"

Expand Down Expand Up @@ -2200,6 +2209,108 @@ TEST_F(SofieAlpakaTest, ConvWithAsymmetricPadding)
}
}

// Feeds the ramp 0, 1, ..., 99 through SOFIE_ConvGroup2::Session (instantiated
// with alpaka::TagGpuCudaRt) and checks every output element against
// ConvGroup2_ExpectedOutput::correct within DEFAULT_TOLERANCE.
TEST_F(SofieAlpakaTest, ConvGroup2)
{
   constexpr float TOLERANCE = DEFAULT_TOLERANCE;
   // Number of reference values; sizes both the host result buffer and the comparison loop.
   constexpr size_t nOut = sizeof(ConvGroup2_ExpectedOutput::correct) / sizeof(float);

   std::vector<float> input(100);
   std::iota(input.begin(), input.end(), 0.0f);

   // Stage the input in a host buffer, then copy it to the device.
   auto input_h = alpaka::allocBuf<float, Idx>(host, Ext1D::all(Idx{input.size()}));
   float* input_ptr = reinterpret_cast<float*>(alpaka::getPtrNative(input_h));
   for (Idx i = 0; i < input.size(); ++i) input_ptr[i] = input[i];

   auto input_d = alpaka::allocBuf<float, Idx>(device, Ext1D::all(Idx{input.size()}));
   alpaka::memcpy(queue, input_d, input_h);
   alpaka::wait(queue);

   auto result_h = alpaka::allocBuf<float, Idx>(host, Ext1D::all(Idx{nOut}));

   {
      // The session and the object returned by infer() live only for the inference and copy-back.
      SOFIE_ConvGroup2::Session<alpaka::TagGpuCudaRt> session("ConvGroup2_FromONNX_GPU_ALPAKA.dat");
      auto result = session.infer(input_d);
      alpaka::wait(queue);

      // A failed device run must stop the test here; otherwise the copy-back below
      // would compare stale data and report misleading per-element mismatches.
      const cudaError_t syncStatus = cudaDeviceSynchronize();
      ASSERT_EQ(syncStatus, cudaSuccess) << cudaGetErrorString(syncStatus);

      alpaka::memcpy(queue, result_h, result);
      alpaka::wait(queue);
   }

   float* res_ptr = reinterpret_cast<float*>(alpaka::getPtrNative(result_h));
   float* correct = ConvGroup2_ExpectedOutput::correct;

   // EXPECT (non-fatal) so that every mismatching index is reported in one run.
   for (size_t i = 0; i < nOut; ++i)
      EXPECT_LE(std::abs(res_ptr[i] - correct[i]), TOLERANCE) << "i=" << i;
}

// Feeds the ramp 0, 1, ..., 99 through SOFIE_ConvGroup4::Session (instantiated
// with alpaka::TagGpuCudaRt) and checks every output element against
// ConvGroup4_ExpectedOutput::correct within DEFAULT_TOLERANCE.
TEST_F(SofieAlpakaTest, ConvGroup4)
{
   constexpr float TOLERANCE = DEFAULT_TOLERANCE;
   // Number of reference values; sizes both the host result buffer and the comparison loop.
   constexpr size_t nOut = sizeof(ConvGroup4_ExpectedOutput::correct) / sizeof(float);

   std::vector<float> input(100);
   std::iota(input.begin(), input.end(), 0.0f);

   // Stage the input in a host buffer, then copy it to the device.
   auto input_h = alpaka::allocBuf<float, Idx>(host, Ext1D::all(Idx{input.size()}));
   float* input_ptr = reinterpret_cast<float*>(alpaka::getPtrNative(input_h));
   for (Idx i = 0; i < input.size(); ++i) input_ptr[i] = input[i];

   auto input_d = alpaka::allocBuf<float, Idx>(device, Ext1D::all(Idx{input.size()}));
   alpaka::memcpy(queue, input_d, input_h);
   alpaka::wait(queue);

   auto result_h = alpaka::allocBuf<float, Idx>(host, Ext1D::all(Idx{nOut}));

   {
      // The session and the object returned by infer() live only for the inference and copy-back.
      SOFIE_ConvGroup4::Session<alpaka::TagGpuCudaRt> session("ConvGroup4_FromONNX_GPU_ALPAKA.dat");
      auto result = session.infer(input_d);
      alpaka::wait(queue);

      // A failed device run must stop the test here; otherwise the copy-back below
      // would compare stale data and report misleading per-element mismatches.
      const cudaError_t syncStatus = cudaDeviceSynchronize();
      ASSERT_EQ(syncStatus, cudaSuccess) << cudaGetErrorString(syncStatus);

      alpaka::memcpy(queue, result_h, result);
      alpaka::wait(queue);
   }

   float* res_ptr = reinterpret_cast<float*>(alpaka::getPtrNative(result_h));
   float* correct = ConvGroup4_ExpectedOutput::correct;

   // EXPECT (non-fatal) so that every mismatching index is reported in one run.
   for (size_t i = 0; i < nOut; ++i)
      EXPECT_LE(std::abs(res_ptr[i] - correct[i]), TOLERANCE) << "i=" << i;
}

// Feeds the ramp 0, 1, ..., 399 through SOFIE_ConvBatch4Group2::Session
// (instantiated with alpaka::TagGpuCudaRt) and checks every output element
// against ConvBatch4Group2_ExpectedOutput::correct within DEFAULT_TOLERANCE.
TEST_F(SofieAlpakaTest, ConvBatch4Group2)
{
   constexpr float TOLERANCE = DEFAULT_TOLERANCE;
   // Number of reference values; sizes both the host result buffer and the comparison loop.
   constexpr size_t nOut = sizeof(ConvBatch4Group2_ExpectedOutput::correct) / sizeof(float);

   std::vector<float> input(400);
   std::iota(input.begin(), input.end(), 0.0f);

   // Stage the input in a host buffer, then copy it to the device.
   auto input_h = alpaka::allocBuf<float, Idx>(host, Ext1D::all(Idx{input.size()}));
   float* input_ptr = reinterpret_cast<float*>(alpaka::getPtrNative(input_h));
   for (Idx i = 0; i < input.size(); ++i) input_ptr[i] = input[i];

   auto input_d = alpaka::allocBuf<float, Idx>(device, Ext1D::all(Idx{input.size()}));
   alpaka::memcpy(queue, input_d, input_h);
   alpaka::wait(queue);

   auto result_h = alpaka::allocBuf<float, Idx>(host, Ext1D::all(Idx{nOut}));

   {
      // The session and the object returned by infer() live only for the inference and copy-back.
      SOFIE_ConvBatch4Group2::Session<alpaka::TagGpuCudaRt> session("ConvBatch4Group2_FromONNX_GPU_ALPAKA.dat");
      auto result = session.infer(input_d);
      alpaka::wait(queue);

      // A failed device run must stop the test here; otherwise the copy-back below
      // would compare stale data and report misleading per-element mismatches.
      const cudaError_t syncStatus = cudaDeviceSynchronize();
      ASSERT_EQ(syncStatus, cudaSuccess) << cudaGetErrorString(syncStatus);

      alpaka::memcpy(queue, result_h, result);
      alpaka::wait(queue);
   }

   float* res_ptr = reinterpret_cast<float*>(alpaka::getPtrNative(result_h));
   float* correct = ConvBatch4Group2_ExpectedOutput::correct;

   // EXPECT (non-fatal) so that every mismatching index is reported in one run.
   for (size_t i = 0; i < nOut; ++i)
      EXPECT_LE(std::abs(res_ptr[i] - correct[i]), TOLERANCE) << "i=" << i;
}

TEST_F(SofieAlpakaTest, BatchNormalization)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;
Expand Down
Binary file added core/test/input_models/ConvBatch4Group2.onnx
Binary file not shown.
Binary file added core/test/input_models/ConvGroup2.onnx
Binary file not shown.
Binary file added core/test/input_models/ConvGroup4.onnx
Binary file not shown.
3 changes: 3 additions & 0 deletions core/test/input_models/references/ConvBatch4Group2.ref.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Reference output for the ConvBatch4Group2 test in
// core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx, which compares these
// values element-wise against the inference result for the input 0, 1, ..., 399.
namespace ConvBatch4Group2_ExpectedOutput {
// The test derives the element count from sizeof(correct) and reads the array
// through a non-const `float*`, so adding `const` here would require changing the test too.
float correct[] = {1908.000000f, 2840.000000f, 2972.000000f, 3104.000000f, 2036.000000f, 2958.000000f, 4368.000000f, 4539.000000f, 4710.000000f, 3066.000000f, 3558.000000f, 5223.000000f, 5394.000000f, 5565.000000f, 3606.000000f, 4158.000000f, 6078.000000f, 6249.000000f, 6420.000000f, 4146.000000f, 2556.000000f, 3704.000000f, 3800.000000f, 3896.000000f, 2492.000000f, 4140.000000f, 6296.000000f, 6644.000000f, 6992.000000f, 4700.000000f, 6846.000000f, 10362.000000f, 10857.000000f, 11352.000000f, 7602.000000f, 8526.000000f, 12837.000000f, 13332.000000f, 13827.000000f, 9222.000000f, 10206.000000f, 15312.000000f, 15807.000000f, 16302.000000f, 10842.000000f, 6948.000000f, 10400.000000f, 10712.000000f, 11024.000000f, 7316.000000f, 25372.000000f, 37952.000000f, 38516.000000f, 39080.000000f, 25964.000000f, 38334.000000f, 57306.000000f, 58125.000000f, 58944.000000f, 39138.000000f, 41094.000000f, 61401.000000f, 62220.000000f, 63039.000000f, 41838.000000f, 43854.000000f, 65496.000000f, 66315.000000f, 67134.000000f, 44538.000000f, 29140.000000f, 43496.000000f, 44024.000000f, 44552.000000f, 29540.000000f, 34804.000000f, 52208.000000f, 52988.000000f, 53768.000000f, 35828.000000f, 53022.000000f, 79500.000000f, 80643.000000f, 81786.000000f, 54474.000000f, 56862.000000f, 85215.000000f, 86358.000000f, 87501.000000f, 58254.000000f, 60702.000000f, 90930.000000f, 92073.000000f, 93216.000000f, 62034.000000f, 40732.000000f, 60992.000000f, 61736.000000f, 62480.000000f, 41564.000000f, 11108.000000f, 16040.000000f, 16172.000000f, 16304.000000f, 10436.000000f, 14958.000000f, 21468.000000f, 21639.000000f, 21810.000000f, 13866.000000f, 15558.000000f, 22323.000000f, 22494.000000f, 22665.000000f, 14406.000000f, 16158.000000f, 23178.000000f, 23349.000000f, 23520.000000f, 14946.000000f, 9356.000000f, 13304.000000f, 13400.000000f, 13496.000000f, 8492.000000f, 27740.000000f, 41096.000000f, 41444.000000f, 41792.000000f, 27500.000000f, 40446.000000f, 59862.000000f, 60357.000000f, 60852.000000f, 
40002.000000f, 42126.000000f, 62337.000000f, 62832.000000f, 63327.000000f, 41622.000000f, 43806.000000f, 64812.000000f, 65307.000000f, 65802.000000f, 43242.000000f, 28148.000000f, 41600.000000f, 41912.000000f, 42224.000000f, 27716.000000f, 63372.000000f, 94352.000000f, 94916.000000f, 95480.000000f, 63164.000000f, 93534.000000f, 139206.000000f, 140025.000000f, 140844.000000f, 93138.000000f, 96294.000000f, 143301.000000f, 144120.000000f, 144939.000000f, 95838.000000f, 99054.000000f, 147396.000000f, 148215.000000f, 149034.000000f, 98538.000000f, 64740.000000f, 96296.000000f, 96824.000000f, 97352.000000f, 64340.000000f, 87204.000000f, 130208.000000f, 130988.000000f, 131768.000000f, 87428.000000f, 129822.000000f, 193800.000000f, 194943.000000f, 196086.000000f, 130074.000000f, 133662.000000f, 199515.000000f, 200658.000000f, 201801.000000f, 133854.000000f, 137502.000000f, 205230.000000f, 206373.000000f, 207516.000000f, 137634.000000f, 90732.000000f, 135392.000000f, 136136.000000f, 136880.000000f, 90764.000000f, 20308.000000f, 29240.000000f, 29372.000000f, 29504.000000f, 18836.000000f, 26958.000000f, 38568.000000f, 38739.000000f, 38910.000000f, 24666.000000f, 27558.000000f, 39423.000000f, 39594.000000f, 39765.000000f, 25206.000000f, 28158.000000f, 40278.000000f, 40449.000000f, 40620.000000f, 25746.000000f, 16156.000000f, 22904.000000f, 23000.000000f, 23096.000000f, 14492.000000f, 51340.000000f, 75896.000000f, 76244.000000f, 76592.000000f, 50300.000000f, 74046.000000f, 109362.000000f, 109857.000000f, 110352.000000f, 72402.000000f, 75726.000000f, 111837.000000f, 112332.000000f, 112827.000000f, 74022.000000f, 77406.000000f, 114312.000000f, 114807.000000f, 115302.000000f, 75642.000000f, 49348.000000f, 72800.000000f, 73112.000000f, 73424.000000f, 48116.000000f, 101372.000000f, 150752.000000f, 151316.000000f, 151880.000000f, 100364.000000f, 148734.000000f, 221106.000000f, 221925.000000f, 222744.000000f, 147138.000000f, 151494.000000f, 225201.000000f, 226020.000000f, 
226839.000000f, 149838.000000f, 154254.000000f, 229296.000000f, 230115.000000f, 230934.000000f, 152538.000000f, 100340.000000f, 149096.000000f, 149624.000000f, 150152.000000f, 99140.000000f, 139604.000000f, 208208.000000f, 208988.000000f, 209768.000000f, 139028.000000f, 206622.000000f, 308100.000000f, 309243.000000f, 310386.000000f, 205674.000000f, 210462.000000f, 313815.000000f, 314958.000000f, 316101.000000f, 209454.000000f, 214302.000000f, 319530.000000f, 320673.000000f, 321816.000000f, 213234.000000f, 140732.000000f, 209792.000000f, 210536.000000f, 211280.000000f, 139964.000000f, 29508.000000f, 42440.000000f, 42572.000000f, 42704.000000f, 27236.000000f, 38958.000000f, 55668.000000f, 55839.000000f, 56010.000000f, 35466.000000f, 39558.000000f, 56523.000000f, 56694.000000f, 56865.000000f, 36006.000000f, 40158.000000f, 57378.000000f, 57549.000000f, 57720.000000f, 36546.000000f, 22956.000000f, 32504.000000f, 32600.000000f, 32696.000000f, 20492.000000f, 74940.000000f, 110696.000000f, 111044.000000f, 111392.000000f, 73100.000000f, 107646.000000f, 158862.000000f, 159357.000000f, 159852.000000f, 104802.000000f, 109326.000000f, 161337.000000f, 161832.000000f, 162327.000000f, 106422.000000f, 111006.000000f, 163812.000000f, 164307.000000f, 164802.000000f, 108042.000000f, 70548.000000f, 104000.000000f, 104312.000000f, 104624.000000f, 68516.000000f, 139372.000000f, 207152.000000f, 207716.000000f, 208280.000000f, 137564.000000f, 203934.000000f, 303006.000000f, 303825.000000f, 304644.000000f, 201138.000000f, 206694.000000f, 307101.000000f, 307920.000000f, 308739.000000f, 203838.000000f, 209454.000000f, 311196.000000f, 312015.000000f, 312834.000000f, 206538.000000f, 135940.000000f, 201896.000000f, 202424.000000f, 202952.000000f, 133940.000000f, 192004.000000f, 286208.000000f, 286988.000000f, 287768.000000f, 190628.000000f, 283422.000000f, 422400.000000f, 423543.000000f, 424686.000000f, 281274.000000f, 287262.000000f, 428115.000000f, 429258.000000f, 430401.000000f, 
285054.000000f, 291102.000000f, 433830.000000f, 434973.000000f, 436116.000000f, 288834.000000f, 190732.000000f, 284192.000000f, 284936.000000f, 285680.000000f, 189164.000000f};
} // namespace ConvBatch4Group2_ExpectedOutput
3 changes: 3 additions & 0 deletions core/test/input_models/references/ConvGroup2.ref.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Reference output for the ConvGroup2 test in
// core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx, which compares these
// values element-wise against the inference result for the input 0, 1, ..., 99.
namespace ConvGroup2_ExpectedOutput {
// The test derives the element count from sizeof(correct) and reads the array
// through a non-const `float*`, so adding `const` here would require changing the test too.
float correct[] = {1908.000000f, 2840.000000f, 2972.000000f, 3104.000000f, 2036.000000f, 2958.000000f, 4368.000000f, 4539.000000f, 4710.000000f, 3066.000000f, 3558.000000f, 5223.000000f, 5394.000000f, 5565.000000f, 3606.000000f, 4158.000000f, 6078.000000f, 6249.000000f, 6420.000000f, 4146.000000f, 2556.000000f, 3704.000000f, 3800.000000f, 3896.000000f, 2492.000000f, 4140.000000f, 6296.000000f, 6644.000000f, 6992.000000f, 4700.000000f, 6846.000000f, 10362.000000f, 10857.000000f, 11352.000000f, 7602.000000f, 8526.000000f, 12837.000000f, 13332.000000f, 13827.000000f, 9222.000000f, 10206.000000f, 15312.000000f, 15807.000000f, 16302.000000f, 10842.000000f, 6948.000000f, 10400.000000f, 10712.000000f, 11024.000000f, 7316.000000f, 25372.000000f, 37952.000000f, 38516.000000f, 39080.000000f, 25964.000000f, 38334.000000f, 57306.000000f, 58125.000000f, 58944.000000f, 39138.000000f, 41094.000000f, 61401.000000f, 62220.000000f, 63039.000000f, 41838.000000f, 43854.000000f, 65496.000000f, 66315.000000f, 67134.000000f, 44538.000000f, 29140.000000f, 43496.000000f, 44024.000000f, 44552.000000f, 29540.000000f, 34804.000000f, 52208.000000f, 52988.000000f, 53768.000000f, 35828.000000f, 53022.000000f, 79500.000000f, 80643.000000f, 81786.000000f, 54474.000000f, 56862.000000f, 85215.000000f, 86358.000000f, 87501.000000f, 58254.000000f, 60702.000000f, 90930.000000f, 92073.000000f, 93216.000000f, 62034.000000f, 40732.000000f, 60992.000000f, 61736.000000f, 62480.000000f, 41564.000000f};
} // namespace ConvGroup2_ExpectedOutput
3 changes: 3 additions & 0 deletions core/test/input_models/references/ConvGroup4.ref.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Reference output for the ConvGroup4 test in
// core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx, which compares these
// values element-wise against the inference result for the input 0, 1, ..., 99.
namespace ConvGroup4_ExpectedOutput {
// The test derives the element count from sizeof(correct) and reads the array
// through a non-const `float*`, so adding `const` here would require changing the test too.
float correct[] = {100.000000f, 163.000000f, 202.000000f, 241.000000f, 160.000000f, 243.000000f, 366.000000f, 411.000000f, 456.000000f, 291.000000f, 408.000000f, 591.000000f, 636.000000f, 681.000000f, 426.000000f, 573.000000f, 816.000000f, 861.000000f, 906.000000f, 561.000000f, 304.000000f, 415.000000f, 436.000000f, 457.000000f, 268.000000f, 1808.000000f, 2677.000000f, 2770.000000f, 2863.000000f, 1876.000000f, 2715.000000f, 4002.000000f, 4128.000000f, 4254.000000f, 2775.000000f, 3150.000000f, 4632.000000f, 4758.000000f, 4884.000000f, 3180.000000f, 3585.000000f, 5262.000000f, 5388.000000f, 5514.000000f, 3585.000000f, 2252.000000f, 3289.000000f, 3364.000000f, 3439.000000f, 2224.000000f, 5316.000000f, 7891.000000f, 8038.000000f, 8185.000000f, 5392.000000f, 7887.000000f, 11688.000000f, 11895.000000f, 12102.000000f, 7959.000000f, 8592.000000f, 12723.000000f, 12930.000000f, 13137.000000f, 8634.000000f, 9297.000000f, 13758.000000f, 13965.000000f, 14172.000000f, 9309.000000f, 6000.000000f, 8863.000000f, 8992.000000f, 9121.000000f, 5980.000000f, 10624.000000f, 15805.000000f, 16006.000000f, 16207.000000f, 10708.000000f, 15759.000000f, 23424.000000f, 23712.000000f, 24000.000000f, 15843.000000f, 16734.000000f, 24864.000000f, 25152.000000f, 25440.000000f, 16788.000000f, 17709.000000f, 26304.000000f, 26592.000000f, 26880.000000f, 17733.000000f, 11548.000000f, 17137.000000f, 17320.000000f, 17503.000000f, 11536.000000f};
} // namespace ConvGroup4_ExpectedOutput