diff --git a/core/inc/SOFIE/ROperator_Conv.hxx b/core/inc/SOFIE/ROperator_Conv.hxx index 835a0ff..2cdbcbc 100644 --- a/core/inc/SOFIE/ROperator_Conv.hxx +++ b/core/inc/SOFIE/ROperator_Conv.hxx @@ -317,7 +317,9 @@ public: } std::vector shape1 = {fShapeW[0], fShapeW[1], kernelSize}; - std::vector shape2 = {Dim{fShapeW[1]}, Dim{kernelSize}, channelDim }; + // _xcol holds the im2col of every batch sample, so the non-grouped GPU path can + // run a single strided-batched GEMM over all samples (each gets its own slice). + std::vector shape2 = {fShapeX[0], Dim{fShapeW[1]}, Dim{kernelSize}, channelDim }; model.AddIntermediateTensor(fNX +"_f", ConvertStringToType(fType), shape1 ); model.AddIntermediateTensor(fNX +"_xcol", ConvertStringToType(fType), shape2 ); convK = fNX +"_f"; @@ -874,22 +876,21 @@ public: // Step 3 + 4: Im2Col then GEMM — structure differs for grouped vs non-grouped // ----------------------------------------------------------------------- if (fAttrGroup == 1) { - // Non-grouped: single im2col per batch, then GEMM - out << SP << SP << "// Step 3: im2col\n"; + // Non-grouped: im2col this sample into its own _xcol slice (slice n). The + // single strided-batched GEMM over all samples is issued after the loop. out << SP << SP << "{\n"; out << SP << SP << SP << "auto const elementsPerThread_im2col = Vec::all(static_cast(1));\n"; out << SP << SP << SP << "auto const elementsPerGrid_im2col = Vec::all(Idx{" << colElements << "});\n"; out << SP << SP << SP << "auto const workDiv_im2col = sofie_workdiv(elementsPerGrid_im2col);\n"; out << SP << SP << SP << "alpaka::exec(queue, workDiv_im2col, im2colKernel_" << opName << ", alpaka::getPtrNative(deviceBuf_" << fNX << ") + x_offset" - << ", alpaka::getPtrNative(deviceBuf_" << imcol << ")" + << ", alpaka::getPtrNative(deviceBuf_" << imcol << ") + n * " << colElements << "u" << ", static_cast(" << colElements << "));\n"; - out << SP << SP << SP << "alpaka::wait(queue);\n"; out << SP << SP << "}\n\n"; if (!fNB.empty()) { size_t biasElements = gemm_n * gemm_m; - out << SP << SP << "// Step 4a: broadcast bias into output slice\n"; + out << SP << SP << "// broadcast bias into this sample's output slice\n"; out << SP << SP << "{\n"; out << SP << SP << SP << "auto const elementsPerThread_bias = Vec::all(static_cast(1));\n"; out << SP << SP << SP << "auto const elementsPerGrid_bias = Vec::all(Idx{" << biasElements << "});\n"; @@ -898,24 +899,8 @@ public: << ", alpaka::getPtrNative(deviceBuf_" << fNB << ")" << ", alpaka::getPtrNative(deviceBuf_" << fNY << ") + out_offset" << ", static_cast(" << biasElements << "));\n"; - out << SP << SP << SP << "alpaka::wait(queue);\n"; out << SP << SP << "}\n\n"; - out << SP << SP << "// Step 4b: GEMM beta=1 accumulates onto bias-initialised output\n"; - out << SP << SP << "blas.matmul('n', 'n', " - << gemm_m << ", " << gemm_n << ", " << gemm_k - << ", 1.0f, alpaka::getPtrNative(deviceBuf_" << imcol << ")" - << ", alpaka::getPtrNative(deviceBuf_" << convK << ")" - << ", 1.0f, alpaka::getPtrNative(deviceBuf_" << fNY << ") + out_offset);\n\n"; - } else { - out << SP << SP << "// Step 4: GEMM beta=0 (no bias)\n"; - out << SP << SP << "blas.matmul('n', 'n', " - << gemm_m << ", " << gemm_n << ", " << gemm_k - << ", 1.0f, alpaka::getPtrNative(deviceBuf_" << imcol << ")" - << ", alpaka::getPtrNative(deviceBuf_" << convK << ")" - << ", 0.0f, alpaka::getPtrNative(deviceBuf_" << fNY << ") + out_offset);\n\n"; } - // Wait for GEMM to finish before next batch overwrites the shared _xcol buffer. - out << SP << SP << "alpaka::wait(queue);\n\n"; } else { // Grouped convolution: im2col and GEMM per group with group-adjusted input pointer. @@ -970,6 +955,21 @@ public: } out << SP << "}\n"; // end batch loop + + // Non-grouped: replace the per-sample matmul loop with one strided-batched GEMM. + // Each sample reads its own _xcol slice (strideA = colElements) and writes its own + // output block (strideC = gemm_n*gemm_m); the weight _f is shared, so strideB = 0. + if (fAttrGroup == 1) { + std::string convBeta = fNB.empty() ? "0.0f" : "1.0f"; + out << SP << "alpaka::wait(queue);\n"; + out << SP << "blas.gemmStridedBatched('n', 'n', " + << gemm_m << ", " << gemm_n << ", " << gemm_k << ", 1.0f, " + << "alpaka::getPtrNative(deviceBuf_" << imcol << "), " << gemm_m << ", " << colElements << ", " + << "alpaka::getPtrNative(deviceBuf_" << convK << "), " << gemm_k << ", 0, " + << convBeta << ", alpaka::getPtrNative(deviceBuf_" << fNY << "), " + << gemm_m << ", " << gemm_n * gemm_m << ", " << bsize << ");\n"; + out << SP << "alpaka::wait(queue);\n"; + } return out.str(); } @@ -979,6 +979,9 @@ public: std::string GetBlasConfig(){ + // Non-grouped Conv uses gemmStridedBatched (legacy cuBLAS, no cuBLASLt layout + // registration). Grouped Conv still uses the per-group matmul path below. + if (fAttrGroup == 1) return ""; size_t oDepth_ = (fDim > 2) ? fShapeY[2].dim : 1; size_t oHeight_ = (fDim > 1) ? fShapeY[fDim].dim : 1; size_t oWidth_ = fShapeY[fDim + 1].dim; diff --git a/core/test/ConvBatchModelGenerator.py b/core/test/ConvBatchModelGenerator.py new file mode 100644 index 0000000..1082790 --- /dev/null +++ b/core/test/ConvBatchModelGenerator.py @@ -0,0 +1,92 @@ +#!/usr/bin/python3 +# +# ConvBatchModelGenerator.py +# +# Generates a batch>1 Conv ONNX model and its reference output for the SOFIE +# alpaka GPU test oracle (ConvBatch4). The reference is computed with +# onnxruntime so it reflects the exact ONNX Conv semantics, independent of +# anything SOFIE does. This model is the correctness oracle for the +# strided-batched Conv GEMM work (Conv batch>1 path). +# +# Usage: python3 ConvBatchModelGenerator.py +# Needs: pip install onnx numpy +# Writes: input_models/ConvBatch4.onnx +# input_models/references/ConvBatch4.ref.hxx (expected output) +# input_models/references/ConvBatch4_input.ref.hxx (input data) + +import os +import numpy as np +import onnx +from onnx import helper, TensorProto, numpy_helper + + +def conv2d_ref(X, W, Bs, pad, stride): + # Plain cross-correlation, matching ONNX Conv semantics for the + # symmetric-pad, unit-stride, group=1, no-dilation case used here. + Bn, Cin, H, Wd = X.shape + Cout, _, K, _ = W.shape + Xp = np.pad(X, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode="constant") + oH = (H + 2 * pad - K) // stride + 1 + oW = (Wd + 2 * pad - K) // stride + 1 + Y = np.empty((Bn, Cout, oH, oW), np.float32) + for b in range(Bn): + for co in range(Cout): + for i in range(oH): + for j in range(oW): + patch = Xp[b, :, i * stride:i * stride + K, j * stride:j * stride + K] + Y[b, co, i, j] = np.float32(np.sum(patch * W[co]) + Bs[co]) + return Y + +NAME = "ConvBatch4" +B, Cin, Cout, H, W = 4, 2, 3, 5, 5 # batch=4 is the point of this oracle +K, PAD, STRIDE = 3, 1, 1 + +OH = (H + 2 * PAD - K) // STRIDE + 1 +OW = (W + 2 * PAD - K) // STRIDE + 1 + +np.random.seed(42) +X = np.random.randn(B, Cin, H, W).astype(np.float32) +Wt = (np.random.randn(Cout, Cin, K, K) * 0.2).astype(np.float32) +Bs = (np.random.randn(Cout) * 0.1).astype(np.float32) + +node = helper.make_node( + "Conv", ["input", "W", "B"], ["output"], + kernel_shape=[K, K], pads=[PAD, PAD, PAD, PAD], + strides=[STRIDE, STRIDE], dilations=[1, 1], group=1) + +graph = helper.make_graph( + [node], NAME, + [helper.make_tensor_value_info("input", TensorProto.FLOAT, [B, Cin, H, W])], + [helper.make_tensor_value_info("output", TensorProto.FLOAT, [B, Cout, OH, OW])], + [numpy_helper.from_array(Wt, "W"), numpy_helper.from_array(Bs, "B")]) + +model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) +model.ir_version = 8 +onnx.checker.check_model(model) + +# Reference output computed independently of SOFIE. +Y = conv2d_ref(X, Wt, Bs, PAD, STRIDE) + +here = os.path.dirname(os.path.abspath(__file__)) +imdir = os.path.join(here, "input_models") +refdir = os.path.join(imdir, "references") +os.makedirs(refdir, exist_ok=True) +onnx.save(model, os.path.join(imdir, NAME + ".onnx")) + + +def emit(path, ns, decl, arr): + body = ", ".join("{:.8f}f".format(v) for v in arr.reshape(-1)) + with open(path, "w") as f: + f.write("// Auto-generated by ConvBatchModelGenerator.py - DO NOT EDIT\n") + f.write("#pragma once\n") + f.write("namespace {} {{\n".format(ns)) + f.write(" {} = {{{}}};\n".format(decl, body)) + f.write("}} // namespace {}\n".format(ns)) + + +emit(os.path.join(refdir, NAME + "_input.ref.hxx"), + NAME + "_Input", "static float data[{}]".format(X.size), X) +emit(os.path.join(refdir, NAME + ".ref.hxx"), + NAME + "_ExpectedOutput", "float output[]", Y) + +print("wrote {}.onnx and references; input {} output {}".format(NAME, X.shape, Y.shape)) diff --git a/core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx b/core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx index 5ad9383..893026f 100644 --- a/core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx +++ b/core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx @@ -164,6 +164,10 @@ #include "ConvWithAsymmetricPadding_FromONNX_GPU_ALPAKA.hxx" #include "input_models/references/ConvWithAsymmetricPadding.ref.hxx" +#include "ConvBatch4_FromONNX_GPU_ALPAKA.hxx" +#include "input_models/references/ConvBatch4.ref.hxx" +#include "input_models/references/ConvBatch4_input.ref.hxx" + #include "BatchNorm_FromONNX_GPU_ALPAKA.hxx" #include "BatchNormRelu_FromONNX_GPU_ALPAKA.hxx" @@ -2442,6 +2446,39 @@ TEST_F(SofieAlpakaTest, ConvWithAsymmetricPadding) } } +TEST_F(SofieAlpakaTest, ConvBatch4) +{ + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + + // Batch=4 input from the generated reference header + constexpr size_t N = sizeof(ConvBatch4_Input::data) / sizeof(float); + auto input_h = alpaka::allocBuf(host, Ext1D::all(Idx{N})); + float* input_ptr = reinterpret_cast(alpaka::getPtrNative(input_h)); + for (Idx i = 0; i < N; ++i) input_ptr[i] = ConvBatch4_Input::data[i]; + + auto input_d = alpaka::allocBuf(device, Ext1D::all(Idx{N})); + alpaka::memcpy(queue, input_d, input_h); + alpaka::wait(queue); + + constexpr size_t nOut = sizeof(ConvBatch4_ExpectedOutput::output) / sizeof(float); + auto result_h = alpaka::allocBuf(host, Ext1D::all(Idx{nOut})); + + { + SOFIE_ConvBatch4::Session session("ConvBatch4_FromONNX_GPU_ALPAKA.dat"); + auto result = session.infer(input_d); + alpaka::wait(queue); + cudaDeviceSynchronize(); + alpaka::memcpy(queue, result_h, result); + alpaka::wait(queue); + } + + float* res_ptr = reinterpret_cast(alpaka::getPtrNative(result_h)); + float* correct = ConvBatch4_ExpectedOutput::output; + for (size_t i = 0; i < nOut; ++i) { + EXPECT_LE(std::abs(res_ptr[i] - correct[i]), TOLERANCE) << "i=" << i; + } +} + TEST_F(SofieAlpakaTest, BatchNormalization) { constexpr float TOLERANCE = DEFAULT_TOLERANCE; diff --git a/core/test/input_models/ConvBatch4.onnx b/core/test/input_models/ConvBatch4.onnx new file mode 100644 index 0000000..6d8fded Binary files /dev/null and b/core/test/input_models/ConvBatch4.onnx differ diff --git a/core/test/input_models/references/ConvBatch4.ref.hxx b/core/test/input_models/references/ConvBatch4.ref.hxx new file mode 100644 index 0000000..5f83bf7 --- /dev/null +++ b/core/test/input_models/references/ConvBatch4.ref.hxx @@ -0,0 +1,5 @@ +// Auto-generated by ConvBatchModelGenerator.py - DO NOT EDIT +#pragma once +namespace ConvBatch4_ExpectedOutput { + float output[] = {-0.03195293f, -0.44289374f, -0.25148013f, -0.17015593f, -0.19932672f, -0.71857929f, -0.64873677f, 0.06540573f, -0.28458852f, -1.51479936f, 0.27430636f, -0.21598786f, 1.08586359f, 0.25843918f, -1.05735528f, 0.29397851f, -1.22813070f, -1.44281471f, -1.73121357f, -1.41360247f, -0.73737752f, 0.41905475f, -0.10619140f, -0.12995683f, -1.09054959f, 0.00480272f, 0.72242731f, -0.58944398f, -0.60447931f, 0.07878304f, -1.43119335f, 0.56298935f, -0.78426296f, -1.54268718f, 0.47342026f, 1.25520730f, 0.21803036f, -0.20851079f, 0.27346700f, -0.01995897f, 0.16370720f, -0.22034909f, 0.42588338f, -0.59627855f, -0.29241365f, -0.37023157f, -0.56963795f, 0.24267963f, -1.08134496f, 0.62385619f, 0.05782463f, 0.80811042f, 0.26196903f, -0.25693661f, 0.38923186f, -0.65506494f, -1.20507371f, 0.58957851f, -0.85846341f, -1.55607462f, 0.15124553f, -0.00898220f, -0.14019980f, -0.98541522f, -1.19516015f, 0.29602796f, 0.92750454f, 1.61471677f, 0.82137787f, -0.17293184f, -0.02203251f, -0.04945083f, 0.52750534f, 1.47753346f, 0.42430305f, -0.04113004f, 0.15575086f, -0.42321450f, -0.13088074f, -0.36786279f, -0.26988190f, 0.23808815f, -0.07239839f, -0.69945872f, -0.75143027f, -0.05842094f, 0.25516850f, 1.29627681f, 0.91249049f, -1.03885472f, -0.09614705f, -0.43640488f, 0.14894338f, 0.13826089f, -0.42662841f, -0.38174152f, -0.41764873f, 0.75635111f, -0.72875881f, 0.52634460f, 0.29975626f, -0.02850732f, 0.31794560f, -0.41813749f, -0.33273274f, -0.14072102f, -0.44187945f, 0.18443757f, -0.46421683f, 0.80352211f, 0.09791172f, 0.71054292f, -0.23731494f, -0.11801884f, 0.52055764f, -0.21512429f, -0.66273683f, -1.10437763f, 0.70684701f, -0.89043939f, -0.44892120f, 0.34505716f, -1.12635148f, 0.73940432f, -0.92215848f, 0.47199863f, 1.07056689f, 0.50565445f, -0.02993017f, -0.29694647f, -0.19824053f, 0.47338027f, 0.73690319f, 0.54190832f, -0.10611010f, 0.54877001f, -0.49843562f, 0.86112380f, 0.08716756f, -0.63958883f, -1.25493753f, 1.09565914f, 0.42359996f, 1.23477066f, 0.25041544f, -0.23478310f, -1.30485690f, 0.54651511f, -0.09381580f, 0.68012393f, 1.13701499f, 0.11230235f, 0.12619297f, -0.19030818f, -0.09028511f, -1.45047796f, 0.56167859f, -0.90403801f, -0.82075167f, 0.38493961f, 0.85577857f, -1.68061662f, -0.19721150f, -1.81811094f, 0.86795306f, -0.57781249f, 1.86487007f, -0.04269446f, -1.58340085f, 0.54105806f, 0.32966709f, -0.20251061f, 0.07146855f, -0.73388404f, -0.13275960f, -0.56517655f, 0.63664508f, -0.09961121f, 0.36751759f, -0.30378032f, -1.18669665f, -0.61507332f, -0.77732700f, 0.86093783f, -1.04621005f, 1.19616795f, 0.91035879f, -1.64908421f, 1.07979822f, -1.02034903f, -0.00696687f, -0.88661492f, 0.68053216f, 0.28206486f, -0.30980539f, 0.51003075f, -2.12627864f, 1.55075324f, 0.09984571f, 0.31666389f, -0.23893179f, 0.64651191f, -0.47463280f, 0.03045150f, -0.28028804f, -0.02903669f, -0.08425365f, -0.09182890f, 0.74768209f, 0.47763556f, 1.08524728f, 1.05665338f, -2.09318995f, -0.15524380f, 1.10305774f, 0.08724110f, 0.51449955f, 0.03605954f, -0.89257431f, -0.47188252f, 0.22086650f, 0.98427254f, -1.03351998f, 0.54820883f, -0.46885133f, -0.11631396f, 0.44594568f, 0.40414166f, -0.57540429f, 0.17267258f, -0.12081935f, 0.18675523f, 0.35419565f, 2.09992743f, -0.24477072f, 0.38851243f, -0.07517084f, -0.35150605f, -0.35997570f, -0.93321586f, -0.47499138f, 0.22723971f, -0.23042275f, 0.15123741f, -0.73816180f, 0.71998096f, 0.18345900f, 0.10092045f, -0.19809513f, -1.33629334f, -0.06184539f, -0.08708560f, -0.66197276f, -0.03373551f, 0.15822518f, -0.51707822f, 0.24624628f, -0.42941380f, -0.94697809f, -0.19533426f, 0.26275522f, 0.74389839f, -0.16533290f, -1.35670066f, 0.59837604f, -0.10891475f, -0.68936580f, -0.53408009f, -0.43481499f, 0.15429562f, 0.66884530f, 0.95782363f, -0.27496409f, -0.47386920f, 0.05938534f, -0.34027457f, -0.27625149f, 0.52149832f, -0.58093745f, 0.23379827f, 0.39020044f, -0.04858817f, -0.05660284f, -1.02562809f, -1.06651473f, -0.15794961f, 0.64807868f, 0.44769567f, 0.32802519f, -0.40214491f, -0.45705903f, -0.41672075f, 0.26201040f, -0.10507706f, -0.04469840f, 0.37760037f, -0.33522218f, -0.46691108f, -0.19591658f, 0.20714454f}; +} // namespace ConvBatch4_ExpectedOutput diff --git a/core/test/input_models/references/ConvBatch4_input.ref.hxx b/core/test/input_models/references/ConvBatch4_input.ref.hxx new file mode 100644 index 0000000..1711423 --- /dev/null +++ b/core/test/input_models/references/ConvBatch4_input.ref.hxx @@ -0,0 +1,5 @@ +// Auto-generated by ConvBatchModelGenerator.py - DO NOT EDIT +#pragma once +namespace ConvBatch4_Input { + static float data[200] = {0.49671414f, -0.13826430f, 0.64768857f, 1.52302980f, -0.23415338f, -0.23413695f, 1.57921278f, 0.76743472f, -0.46947438f, 0.54256004f, -0.46341768f, -0.46572974f, 0.24196227f, -1.91328025f, -1.72491789f, -0.56228751f, -1.01283109f, 0.31424734f, -0.90802407f, -1.41230369f, 1.46564877f, -0.22577630f, 0.06752820f, -1.42474818f, -0.54438275f, 0.11092259f, -1.15099359f, 0.37569803f, -0.60063869f, -0.29169375f, -0.60170662f, 1.85227823f, -0.01349723f, -1.05771089f, 0.82254493f, -1.22084367f, 0.20886360f, -1.95967007f, -1.32818604f, 0.19686124f, 0.73846656f, 0.17136829f, -0.11564828f, -0.30110368f, -1.47852194f, -0.71984422f, -0.46063876f, 1.05712223f, 0.34361830f, -1.76304018f, 0.32408398f, -0.38508227f, -0.67692202f, 0.61167628f, 1.03099954f, 0.93128014f, -0.83921754f, -0.30921239f, 0.33126342f, 0.97554511f, -0.47917423f, -0.18565898f, -1.10633492f, -1.19620657f, 0.81252581f, 1.35624003f, -0.07201012f, 1.00353289f, 0.36163601f, -0.64511973f, 0.36139560f, 1.53803658f, -0.03582604f, 1.56464362f, -2.61974502f, 0.82190251f, 0.08704707f, -0.29900736f, 0.09176078f, -1.98756886f, -0.21967189f, 0.35711256f, 1.47789407f, -0.51827019f, -0.80849361f, -0.50175703f, 0.91540211f, 0.32875112f, -0.52976018f, 0.51326746f, 0.09707755f, 0.96864498f, -0.70205307f, -0.32766214f, -0.39210814f, -1.46351492f, 0.29612029f, 0.26105526f, 0.00511346f, -0.23458713f, -1.41537070f, -0.42064533f, -0.34271452f, -0.80227727f, -0.16128571f, 0.40405086f, 1.88618588f, 0.17457782f, 0.25755039f, -0.07444592f, -1.91877127f, -0.02651387f, 0.06023021f, 2.46324205f, -0.19236097f, 0.30154735f, -0.03471177f, -1.16867805f, 1.14282286f, 0.75193304f, 0.79103196f, -0.90938747f, 1.40279436f, -1.40185106f, 0.58685708f, 2.19045568f, -0.99053633f, -0.56629771f, 0.09965137f, -0.50347567f, -1.55066347f, 0.06856298f, -1.06230366f, 0.47359243f, -0.91942424f, 1.54993439f, -0.78325331f, -0.32206151f, 0.81351721f, -1.23086429f, 0.22745994f, 1.30714273f, -1.60748327f, 0.18463387f, 0.25988281f, 0.78182286f, -1.23695076f, -1.32045662f, 0.52194154f, 0.29698467f, 0.25049284f, 0.34644821f, -0.68002474f, 0.23225370f, 0.29307246f, -0.71435142f, 1.86577451f, 0.47383294f, -1.19130349f, 0.65655363f, -0.97468168f, 0.78708458f, 1.15859556f, -0.82068235f, 0.96337610f, 0.41278094f, 0.82206017f, 1.89679301f, -0.24538812f, -0.75373614f, -0.88951445f, -0.81581026f, -0.07710171f, 0.34115198f, 0.27669081f, 0.82718325f, 0.01300189f, 1.45353413f, -0.26465684f, 2.72016907f, 0.62566733f, -0.85715753f, -1.07089245f, 0.48247242f, -0.22346279f, 0.71400052f, 0.47323763f, -0.07282891f, -0.84679371f, -1.51484728f, -0.44651496f, 0.85639882f, 0.21409374f, -1.24573874f, 0.17318092f, 0.38531739f, -0.88385743f, 0.15372510f, 0.05820872f, -1.14297032f}; +} // namespace ConvBatch4_Input