Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions core/inc/SOFIE/ROperator_Conv.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,9 @@ public:
}

std::vector<size_t> shape1 = {fShapeW[0], fShapeW[1], kernelSize};
std::vector<Dim> shape2 = {Dim{fShapeW[1]}, Dim{kernelSize}, channelDim };
// _xcol holds the im2col of every batch sample, so the non-grouped GPU path can
// run a single strided-batched GEMM over all samples (each gets its own slice).
std::vector<Dim> shape2 = {fShapeX[0], Dim{fShapeW[1]}, Dim{kernelSize}, channelDim };
model.AddIntermediateTensor(fNX +"_f", ConvertStringToType(fType), shape1 );
model.AddIntermediateTensor(fNX +"_xcol", ConvertStringToType(fType), shape2 );
convK = fNX +"_f";
Expand Down Expand Up @@ -874,22 +876,21 @@ public:
// Step 3 + 4: Im2Col then GEMM — structure differs for grouped vs non-grouped
// -----------------------------------------------------------------------
if (fAttrGroup == 1) {
// Non-grouped: single im2col per batch, then GEMM
out << SP << SP << "// Step 3: im2col\n";
// Non-grouped: im2col this sample into its own _xcol slice (slice n). The
// single strided-batched GEMM over all samples is issued after the loop.
out << SP << SP << "{\n";
out << SP << SP << SP << "auto const elementsPerThread_im2col = Vec::all(static_cast<Idx>(1));\n";
out << SP << SP << SP << "auto const elementsPerGrid_im2col = Vec::all(Idx{" << colElements << "});\n";
out << SP << SP << SP << "auto const workDiv_im2col = sofie_workdiv(elementsPerGrid_im2col);\n";
out << SP << SP << SP << "alpaka::exec<Acc>(queue, workDiv_im2col, im2colKernel_" << opName
<< ", alpaka::getPtrNative(deviceBuf_" << fNX << ") + x_offset"
<< ", alpaka::getPtrNative(deviceBuf_" << imcol << ")"
<< ", alpaka::getPtrNative(deviceBuf_" << imcol << ") + n * " << colElements << "u"
<< ", static_cast<Idx>(" << colElements << "));\n";
out << SP << SP << SP << "alpaka::wait(queue);\n";
out << SP << SP << "}\n\n";

if (!fNB.empty()) {
size_t biasElements = gemm_n * gemm_m;
out << SP << SP << "// Step 4a: broadcast bias into output slice\n";
out << SP << SP << "// broadcast bias into this sample's output slice\n";
out << SP << SP << "{\n";
out << SP << SP << SP << "auto const elementsPerThread_bias = Vec::all(static_cast<Idx>(1));\n";
out << SP << SP << SP << "auto const elementsPerGrid_bias = Vec::all(Idx{" << biasElements << "});\n";
Expand All @@ -898,24 +899,8 @@ public:
<< ", alpaka::getPtrNative(deviceBuf_" << fNB << ")"
<< ", alpaka::getPtrNative(deviceBuf_" << fNY << ") + out_offset"
<< ", static_cast<Idx>(" << biasElements << "));\n";
out << SP << SP << SP << "alpaka::wait(queue);\n";
out << SP << SP << "}\n\n";
out << SP << SP << "// Step 4b: GEMM beta=1 accumulates onto bias-initialised output\n";
out << SP << SP << "blas.matmul('n', 'n', "
<< gemm_m << ", " << gemm_n << ", " << gemm_k
<< ", 1.0f, alpaka::getPtrNative(deviceBuf_" << imcol << ")"
<< ", alpaka::getPtrNative(deviceBuf_" << convK << ")"
<< ", 1.0f, alpaka::getPtrNative(deviceBuf_" << fNY << ") + out_offset);\n\n";
} else {
out << SP << SP << "// Step 4: GEMM beta=0 (no bias)\n";
out << SP << SP << "blas.matmul('n', 'n', "
<< gemm_m << ", " << gemm_n << ", " << gemm_k
<< ", 1.0f, alpaka::getPtrNative(deviceBuf_" << imcol << ")"
<< ", alpaka::getPtrNative(deviceBuf_" << convK << ")"
<< ", 0.0f, alpaka::getPtrNative(deviceBuf_" << fNY << ") + out_offset);\n\n";
}
// Wait for GEMM to finish before next batch overwrites the shared _xcol buffer.
out << SP << SP << "alpaka::wait(queue);\n\n";

} else {
// Grouped convolution: im2col and GEMM per group with group-adjusted input pointer.
Expand Down Expand Up @@ -970,6 +955,21 @@ public:
}

out << SP << "}\n"; // end batch loop

// Non-grouped: replace the per-sample matmul loop with one strided-batched GEMM.
// Each sample reads its own _xcol slice (strideA = colElements) and writes its own
// output block (strideC = gemm_n*gemm_m); the weight _f is shared, so strideB = 0.
if (fAttrGroup == 1) {
std::string convBeta = fNB.empty() ? "0.0f" : "1.0f";
out << SP << "alpaka::wait(queue);\n";
out << SP << "blas.gemmStridedBatched('n', 'n', "
<< gemm_m << ", " << gemm_n << ", " << gemm_k << ", 1.0f, "
<< "alpaka::getPtrNative(deviceBuf_" << imcol << "), " << gemm_m << ", " << colElements << ", "
<< "alpaka::getPtrNative(deviceBuf_" << convK << "), " << gemm_k << ", 0, "
<< convBeta << ", alpaka::getPtrNative(deviceBuf_" << fNY << "), "
<< gemm_m << ", " << gemm_n * gemm_m << ", " << bsize << ");\n";
out << SP << "alpaka::wait(queue);\n";
}
return out.str();
}

Expand All @@ -979,6 +979,9 @@ public:


std::string GetBlasConfig(){
// Non-grouped Conv uses gemmStridedBatched (legacy cuBLAS, no cuBLASLt layout
// registration). Grouped Conv still uses the per-group matmul path below.
if (fAttrGroup == 1) return "";
size_t oDepth_ = (fDim > 2) ? fShapeY[2].dim : 1;
size_t oHeight_ = (fDim > 1) ? fShapeY[fDim].dim : 1;
size_t oWidth_ = fShapeY[fDim + 1].dim;
Expand Down
92 changes: 92 additions & 0 deletions core/test/ConvBatchModelGenerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#!/usr/bin/python3
#
# ConvBatchModelGenerator.py
#
# Generates a batch>1 Conv ONNX model and its reference output for the SOFIE
# alpaka GPU test oracle (ConvBatch4). The reference is computed with
# onnxruntime so it reflects the exact ONNX Conv semantics, independent of
# anything SOFIE does. This model is the correctness oracle for the
# strided-batched Conv GEMM work (Conv batch>1 path).
#
# Usage: python3 ConvBatchModelGenerator.py
# Needs: pip install onnx numpy
# Writes: input_models/ConvBatch4.onnx
# input_models/references/ConvBatch4.ref.hxx (expected output)
# input_models/references/ConvBatch4_input.ref.hxx (input data)

import os
import numpy as np
import onnx
from onnx import helper, TensorProto, numpy_helper


def conv2d_ref(X, W, Bs, pad, stride):
# Plain cross-correlation, matching ONNX Conv semantics for the
# symmetric-pad, unit-stride, group=1, no-dilation case used here.
Bn, Cin, H, Wd = X.shape
Cout, _, K, _ = W.shape
Xp = np.pad(X, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode="constant")
oH = (H + 2 * pad - K) // stride + 1
oW = (Wd + 2 * pad - K) // stride + 1
Y = np.empty((Bn, Cout, oH, oW), np.float32)
for b in range(Bn):
for co in range(Cout):
for i in range(oH):
for j in range(oW):
patch = Xp[b, :, i * stride:i * stride + K, j * stride:j * stride + K]
Y[b, co, i, j] = np.float32(np.sum(patch * W[co]) + Bs[co])
return Y

NAME = "ConvBatch4"
B, Cin, Cout, H, W = 4, 2, 3, 5, 5 # batch=4 is the point of this oracle
K, PAD, STRIDE = 3, 1, 1

OH = (H + 2 * PAD - K) // STRIDE + 1
OW = (W + 2 * PAD - K) // STRIDE + 1

np.random.seed(42)
X = np.random.randn(B, Cin, H, W).astype(np.float32)
Wt = (np.random.randn(Cout, Cin, K, K) * 0.2).astype(np.float32)
Bs = (np.random.randn(Cout) * 0.1).astype(np.float32)

node = helper.make_node(
"Conv", ["input", "W", "B"], ["output"],
kernel_shape=[K, K], pads=[PAD, PAD, PAD, PAD],
strides=[STRIDE, STRIDE], dilations=[1, 1], group=1)

graph = helper.make_graph(
[node], NAME,
[helper.make_tensor_value_info("input", TensorProto.FLOAT, [B, Cin, H, W])],
[helper.make_tensor_value_info("output", TensorProto.FLOAT, [B, Cout, OH, OW])],
[numpy_helper.from_array(Wt, "W"), numpy_helper.from_array(Bs, "B")])

model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
model.ir_version = 8
onnx.checker.check_model(model)

# Reference output computed independently of SOFIE.
Y = conv2d_ref(X, Wt, Bs, PAD, STRIDE)

here = os.path.dirname(os.path.abspath(__file__))
imdir = os.path.join(here, "input_models")
refdir = os.path.join(imdir, "references")
os.makedirs(refdir, exist_ok=True)
onnx.save(model, os.path.join(imdir, NAME + ".onnx"))


def emit(path, ns, decl, arr):
body = ", ".join("{:.8f}f".format(v) for v in arr.reshape(-1))
with open(path, "w") as f:
f.write("// Auto-generated by ConvBatchModelGenerator.py - DO NOT EDIT\n")
f.write("#pragma once\n")
f.write("namespace {} {{\n".format(ns))
f.write(" {} = {{{}}};\n".format(decl, body))
f.write("}} // namespace {}\n".format(ns))


emit(os.path.join(refdir, NAME + "_input.ref.hxx"),
NAME + "_Input", "static float data[{}]".format(X.size), X)
emit(os.path.join(refdir, NAME + ".ref.hxx"),
NAME + "_ExpectedOutput", "float output[]", Y)

print("wrote {}.onnx and references; input {} output {}".format(NAME, X.shape, Y.shape))
37 changes: 37 additions & 0 deletions core/test/TestCustomModelsFromONNXForAlpakaCuda.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@
#include "ConvWithAsymmetricPadding_FromONNX_GPU_ALPAKA.hxx"
#include "input_models/references/ConvWithAsymmetricPadding.ref.hxx"

#include "ConvBatch4_FromONNX_GPU_ALPAKA.hxx"
#include "input_models/references/ConvBatch4.ref.hxx"
#include "input_models/references/ConvBatch4_input.ref.hxx"

#include "BatchNorm_FromONNX_GPU_ALPAKA.hxx"
#include "BatchNormRelu_FromONNX_GPU_ALPAKA.hxx"

Expand Down Expand Up @@ -2442,6 +2446,39 @@ TEST_F(SofieAlpakaTest, ConvWithAsymmetricPadding)
}
}

TEST_F(SofieAlpakaTest, ConvBatch4)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;

// Batch=4 input from the generated reference header
constexpr size_t N = sizeof(ConvBatch4_Input::data) / sizeof(float);
auto input_h = alpaka::allocBuf<float, Idx>(host, Ext1D::all(Idx{N}));
float* input_ptr = reinterpret_cast<float*>(alpaka::getPtrNative(input_h));
for (Idx i = 0; i < N; ++i) input_ptr[i] = ConvBatch4_Input::data[i];

auto input_d = alpaka::allocBuf<float, Idx>(device, Ext1D::all(Idx{N}));
alpaka::memcpy(queue, input_d, input_h);
alpaka::wait(queue);

constexpr size_t nOut = sizeof(ConvBatch4_ExpectedOutput::output) / sizeof(float);
auto result_h = alpaka::allocBuf<float, Idx>(host, Ext1D::all(Idx{nOut}));

{
SOFIE_ConvBatch4::Session<alpaka::TagGpuCudaRt> session("ConvBatch4_FromONNX_GPU_ALPAKA.dat");
auto result = session.infer(input_d);
alpaka::wait(queue);
cudaDeviceSynchronize();
alpaka::memcpy(queue, result_h, result);
alpaka::wait(queue);
}

float* res_ptr = reinterpret_cast<float*>(alpaka::getPtrNative(result_h));
float* correct = ConvBatch4_ExpectedOutput::output;
for (size_t i = 0; i < nOut; ++i) {
EXPECT_LE(std::abs(res_ptr[i] - correct[i]), TOLERANCE) << "i=" << i;
}
}

TEST_F(SofieAlpakaTest, BatchNormalization)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;
Expand Down
Binary file added core/test/input_models/ConvBatch4.onnx
Binary file not shown.
5 changes: 5 additions & 0 deletions core/test/input_models/references/ConvBatch4.ref.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// Auto-generated by ConvBatchModelGenerator.py - DO NOT EDIT
#pragma once
namespace ConvBatch4_ExpectedOutput {
float output[] = {-0.03195293f, -0.44289374f, -0.25148013f, -0.17015593f, -0.19932672f, -0.71857929f, -0.64873677f, 0.06540573f, -0.28458852f, -1.51479936f, 0.27430636f, -0.21598786f, 1.08586359f, 0.25843918f, -1.05735528f, 0.29397851f, -1.22813070f, -1.44281471f, -1.73121357f, -1.41360247f, -0.73737752f, 0.41905475f, -0.10619140f, -0.12995683f, -1.09054959f, 0.00480272f, 0.72242731f, -0.58944398f, -0.60447931f, 0.07878304f, -1.43119335f, 0.56298935f, -0.78426296f, -1.54268718f, 0.47342026f, 1.25520730f, 0.21803036f, -0.20851079f, 0.27346700f, -0.01995897f, 0.16370720f, -0.22034909f, 0.42588338f, -0.59627855f, -0.29241365f, -0.37023157f, -0.56963795f, 0.24267963f, -1.08134496f, 0.62385619f, 0.05782463f, 0.80811042f, 0.26196903f, -0.25693661f, 0.38923186f, -0.65506494f, -1.20507371f, 0.58957851f, -0.85846341f, -1.55607462f, 0.15124553f, -0.00898220f, -0.14019980f, -0.98541522f, -1.19516015f, 0.29602796f, 0.92750454f, 1.61471677f, 0.82137787f, -0.17293184f, -0.02203251f, -0.04945083f, 0.52750534f, 1.47753346f, 0.42430305f, -0.04113004f, 0.15575086f, -0.42321450f, -0.13088074f, -0.36786279f, -0.26988190f, 0.23808815f, -0.07239839f, -0.69945872f, -0.75143027f, -0.05842094f, 0.25516850f, 1.29627681f, 0.91249049f, -1.03885472f, -0.09614705f, -0.43640488f, 0.14894338f, 0.13826089f, -0.42662841f, -0.38174152f, -0.41764873f, 0.75635111f, -0.72875881f, 0.52634460f, 0.29975626f, -0.02850732f, 0.31794560f, -0.41813749f, -0.33273274f, -0.14072102f, -0.44187945f, 0.18443757f, -0.46421683f, 0.80352211f, 0.09791172f, 0.71054292f, -0.23731494f, -0.11801884f, 0.52055764f, -0.21512429f, -0.66273683f, -1.10437763f, 0.70684701f, -0.89043939f, -0.44892120f, 0.34505716f, -1.12635148f, 0.73940432f, -0.92215848f, 0.47199863f, 1.07056689f, 0.50565445f, -0.02993017f, -0.29694647f, -0.19824053f, 0.47338027f, 0.73690319f, 0.54190832f, -0.10611010f, 0.54877001f, -0.49843562f, 0.86112380f, 0.08716756f, -0.63958883f, -1.25493753f, 1.09565914f, 0.42359996f, 1.23477066f, 0.25041544f, -0.23478310f, -1.30485690f, 0.54651511f, -0.09381580f, 0.68012393f, 1.13701499f, 0.11230235f, 0.12619297f, -0.19030818f, -0.09028511f, -1.45047796f, 0.56167859f, -0.90403801f, -0.82075167f, 0.38493961f, 0.85577857f, -1.68061662f, -0.19721150f, -1.81811094f, 0.86795306f, -0.57781249f, 1.86487007f, -0.04269446f, -1.58340085f, 0.54105806f, 0.32966709f, -0.20251061f, 0.07146855f, -0.73388404f, -0.13275960f, -0.56517655f, 0.63664508f, -0.09961121f, 0.36751759f, -0.30378032f, -1.18669665f, -0.61507332f, -0.77732700f, 0.86093783f, -1.04621005f, 1.19616795f, 0.91035879f, -1.64908421f, 1.07979822f, -1.02034903f, -0.00696687f, -0.88661492f, 0.68053216f, 0.28206486f, -0.30980539f, 0.51003075f, -2.12627864f, 1.55075324f, 0.09984571f, 0.31666389f, -0.23893179f, 0.64651191f, -0.47463280f, 0.03045150f, -0.28028804f, -0.02903669f, -0.08425365f, -0.09182890f, 0.74768209f, 0.47763556f, 1.08524728f, 1.05665338f, -2.09318995f, -0.15524380f, 1.10305774f, 0.08724110f, 0.51449955f, 0.03605954f, -0.89257431f, -0.47188252f, 0.22086650f, 0.98427254f, -1.03351998f, 0.54820883f, -0.46885133f, -0.11631396f, 0.44594568f, 0.40414166f, -0.57540429f, 0.17267258f, -0.12081935f, 0.18675523f, 0.35419565f, 2.09992743f, -0.24477072f, 0.38851243f, -0.07517084f, -0.35150605f, -0.35997570f, -0.93321586f, -0.47499138f, 0.22723971f, -0.23042275f, 0.15123741f, -0.73816180f, 0.71998096f, 0.18345900f, 0.10092045f, -0.19809513f, -1.33629334f, -0.06184539f, -0.08708560f, -0.66197276f, -0.03373551f, 0.15822518f, -0.51707822f, 0.24624628f, -0.42941380f, -0.94697809f, -0.19533426f, 0.26275522f, 0.74389839f, -0.16533290f, -1.35670066f, 0.59837604f, -0.10891475f, -0.68936580f, -0.53408009f, -0.43481499f, 0.15429562f, 0.66884530f, 0.95782363f, -0.27496409f, -0.47386920f, 0.05938534f, -0.34027457f, -0.27625149f, 0.52149832f, -0.58093745f, 0.23379827f, 0.39020044f, -0.04858817f, -0.05660284f, -1.02562809f, -1.06651473f, -0.15794961f, 0.64807868f, 0.44769567f, 0.32802519f, -0.40214491f, -0.45705903f, -0.41672075f, 0.26201040f, -0.10507706f, -0.04469840f, 0.37760037f, -0.33522218f, -0.46691108f, -0.19591658f, 0.20714454f};
} // namespace ConvBatch4_ExpectedOutput
5 changes: 5 additions & 0 deletions core/test/input_models/references/ConvBatch4_input.ref.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// Auto-generated by ConvBatchModelGenerator.py - DO NOT EDIT
#pragma once
namespace ConvBatch4_Input {
static float data[200] = {0.49671414f, -0.13826430f, 0.64768857f, 1.52302980f, -0.23415338f, -0.23413695f, 1.57921278f, 0.76743472f, -0.46947438f, 0.54256004f, -0.46341768f, -0.46572974f, 0.24196227f, -1.91328025f, -1.72491789f, -0.56228751f, -1.01283109f, 0.31424734f, -0.90802407f, -1.41230369f, 1.46564877f, -0.22577630f, 0.06752820f, -1.42474818f, -0.54438275f, 0.11092259f, -1.15099359f, 0.37569803f, -0.60063869f, -0.29169375f, -0.60170662f, 1.85227823f, -0.01349723f, -1.05771089f, 0.82254493f, -1.22084367f, 0.20886360f, -1.95967007f, -1.32818604f, 0.19686124f, 0.73846656f, 0.17136829f, -0.11564828f, -0.30110368f, -1.47852194f, -0.71984422f, -0.46063876f, 1.05712223f, 0.34361830f, -1.76304018f, 0.32408398f, -0.38508227f, -0.67692202f, 0.61167628f, 1.03099954f, 0.93128014f, -0.83921754f, -0.30921239f, 0.33126342f, 0.97554511f, -0.47917423f, -0.18565898f, -1.10633492f, -1.19620657f, 0.81252581f, 1.35624003f, -0.07201012f, 1.00353289f, 0.36163601f, -0.64511973f, 0.36139560f, 1.53803658f, -0.03582604f, 1.56464362f, -2.61974502f, 0.82190251f, 0.08704707f, -0.29900736f, 0.09176078f, -1.98756886f, -0.21967189f, 0.35711256f, 1.47789407f, -0.51827019f, -0.80849361f, -0.50175703f, 0.91540211f, 0.32875112f, -0.52976018f, 0.51326746f, 0.09707755f, 0.96864498f, -0.70205307f, -0.32766214f, -0.39210814f, -1.46351492f, 0.29612029f, 0.26105526f, 0.00511346f, -0.23458713f, -1.41537070f, -0.42064533f, -0.34271452f, -0.80227727f, -0.16128571f, 0.40405086f, 1.88618588f, 0.17457782f, 0.25755039f, -0.07444592f, -1.91877127f, -0.02651387f, 0.06023021f, 2.46324205f, -0.19236097f, 0.30154735f, -0.03471177f, -1.16867805f, 1.14282286f, 0.75193304f, 0.79103196f, -0.90938747f, 1.40279436f, -1.40185106f, 0.58685708f, 2.19045568f, -0.99053633f, -0.56629771f, 0.09965137f, -0.50347567f, -1.55066347f, 0.06856298f, -1.06230366f, 0.47359243f, -0.91942424f, 1.54993439f, -0.78325331f, -0.32206151f, 0.81351721f, -1.23086429f, 0.22745994f, 1.30714273f, -1.60748327f, 0.18463387f, 0.25988281f, 0.78182286f, -1.23695076f, -1.32045662f, 0.52194154f, 0.29698467f, 0.25049284f, 0.34644821f, -0.68002474f, 0.23225370f, 0.29307246f, -0.71435142f, 1.86577451f, 0.47383294f, -1.19130349f, 0.65655363f, -0.97468168f, 0.78708458f, 1.15859556f, -0.82068235f, 0.96337610f, 0.41278094f, 0.82206017f, 1.89679301f, -0.24538812f, -0.75373614f, -0.88951445f, -0.81581026f, -0.07710171f, 0.34115198f, 0.27669081f, 0.82718325f, 0.01300189f, 1.45353413f, -0.26465684f, 2.72016907f, 0.62566733f, -0.85715753f, -1.07089245f, 0.48247242f, -0.22346279f, 0.71400052f, 0.47323763f, -0.07282891f, -0.84679371f, -1.51484728f, -0.44651496f, 0.85639882f, 0.21409374f, -1.24573874f, 0.17318092f, 0.38531739f, -0.88385743f, 0.15372510f, 0.05820872f, -1.14297032f};
} // namespace ConvBatch4_Input