Skip to content

Commit d0fbeb8

Browse files
gbaraldiclaude
andcommitted
GCN: use byref instead of byval+lower_byval for kernel arguments
On AMDGPU, kernel arguments already reside in the read-only kernarg segment. The current pipeline adds `byval` attributes and then `lower_byval` expands them into first-class aggregates (FCAs), which forces LLVM to extractvalue every field and store the entire struct into scratch memory via alloca — even when only a few fields are used. For large structs (e.g. Oceananigans' ImmersedBoundaryGrid), this produces dozens of dead scratch stores. Using `byref` instead keeps the pointer semantics, allowing LLVM to generate scalar loads directly from the kernarg segment on demand. The invariant.load and TBAA metadata that Julia emits remain valid since the kernarg memory is immutable. The byref pointer parameters are rewritten to addrspace(4) (AMDGPU constant/kernarg address space), with addrspacecasts inserted so the function body can continue using generic pointers. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c1b651a commit d0fbeb8

3 files changed

Lines changed: 97 additions & 3 deletions

File tree

src/gcn.jl

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,90 @@ runtime_slug(job::CompilerJob{GCNCompilerTarget}) = "gcn-$(job.config.target.dev
4040
const gcn_intrinsics = () # TODO: ("vprintf", "__assertfail", "malloc", "free")
4141
isintrinsic(::CompilerJob{GCNCompilerTarget}, fn::String) = in(fn, gcn_intrinsics)
4242

43+
pass_by_ref(@nospecialize(job::CompilerJob{GCNCompilerTarget})) = true
44+
45+
# AMDGPU constant/kernarg address space
46+
const GCN_ADDRSPACE_CONSTANT = 4
47+
48+
# Rewrite byref pointer parameters from addrspace 0 to addrspace 4 (kernarg),
49+
# inserting addrspacecasts so the function body can continue using generic pointers.
50+
function rewrite_byref_addrspaces!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
51+
mod::LLVM.Module, f::LLVM.Function)
52+
ft = function_type(f)
53+
54+
# find byref parameters
55+
byref = BitVector(undef, length(parameters(ft)))
56+
for i in 1:length(byref)
57+
byref[i] = false
58+
for attr in collect(parameter_attributes(f, i))
59+
if kind(attr) == kind(TypeAttribute("byref", LLVM.VoidType()))
60+
byref[i] = true
61+
end
62+
end
63+
end
64+
any(byref) || return f
65+
66+
# build new function type with addrspace(4) pointers for byref params
67+
new_types = LLVMType[]
68+
for (i, param) in enumerate(parameters(ft))
69+
if byref[i]
70+
push!(new_types, LLVM.PointerType(GCN_ADDRSPACE_CONSTANT))
71+
else
72+
push!(new_types, param)
73+
end
74+
end
75+
new_ft = LLVM.FunctionType(return_type(ft), new_types)
76+
new_f = LLVM.Function(mod, "", new_ft)
77+
linkage!(new_f, linkage(f))
78+
callconv!(new_f, callconv(f))
79+
for (arg, new_arg) in zip(parameters(f), parameters(new_f))
80+
LLVM.name!(new_arg, LLVM.name(arg))
81+
end
82+
83+
# copy parameter attributes
84+
for (i, _) in enumerate(parameters(ft))
85+
for attr in collect(parameter_attributes(f, i))
86+
push!(parameter_attributes(new_f, i), attr)
87+
end
88+
end
89+
90+
# insert addrspacecasts in entry block
91+
new_args = LLVM.Value[]
92+
@dispose builder=IRBuilder() begin
93+
entry = BasicBlock(new_f, "conversion")
94+
position!(builder, entry)
95+
96+
for (i, param) in enumerate(parameters(ft))
97+
if byref[i]
98+
# cast from addrspace(4) to addrspace(0) for the function body
99+
ptr = addrspacecast!(builder, parameters(new_f)[i], param)
100+
push!(new_args, ptr)
101+
else
102+
push!(new_args, parameters(new_f)[i])
103+
end
104+
end
105+
106+
value_map = Dict{LLVM.Value, LLVM.Value}(
107+
param => new_args[i] for (i, param) in enumerate(parameters(f))
108+
)
109+
value_map[f] = new_f
110+
clone_into!(new_f, f; value_map,
111+
changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
112+
113+
br!(builder, blocks(new_f)[2])
114+
end
115+
116+
# replace old function
117+
fn = LLVM.name(f)
118+
prune_constexpr_uses!(f)
119+
@assert isempty(uses(f))
120+
replace_metadata_uses!(f, new_f)
121+
erase!(f)
122+
LLVM.name!(new_f, fn)
123+
124+
return new_f
125+
end
126+
43127
function finish_module!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
44128
mod::LLVM.Module, entry::LLVM.Function)
45129
lower_throw_extra!(mod)
@@ -48,8 +132,8 @@ function finish_module!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
48132
# calling convention
49133
callconv!(entry, LLVM.API.LLVMAMDGPUKERNELCallConv)
50134

51-
# work around bad byval codegen (JuliaGPU/GPUCompiler.jl#92)
52-
entry = lower_byval(job, mod, entry)
135+
# rewrite byref parameters to use the kernarg address space
136+
entry = rewrite_byref_addrspaces!(job, mod, entry)
53137
end
54138

55139
return entry

src/interface.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,12 @@ kernel_state_type(@nospecialize(job::CompilerJob)) = Nothing
272272
# Does the target need to pass kernel arguments by value?
273273
pass_by_value(@nospecialize(job::CompilerJob)) = true
274274

275+
# Should the target use byref instead of byval+lower_byval for kernel arguments?
276+
# When true, aggregate arguments are passed as pointers with the byref attribute,
277+
# allowing the backend to load fields directly from the argument memory (e.g. kernarg
278+
# segment on AMDGPU) instead of materializing the entire struct via first-class aggregates.
279+
pass_by_ref(@nospecialize(job::CompilerJob)) = false
280+
275281
# whether pointer is a valid call target
276282
valid_function_pointer(@nospecialize(job::CompilerJob), ptr::Ptr{Cvoid}) = false
277283

src/irgen.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,11 @@ function irgen(@nospecialize(job::CompilerJob))
9494
for arg in args
9595
if arg.cc == BITS_REF
9696
llvm_typ = convert(LLVMType, arg.typ)
97-
attr = TypeAttribute("byval", llvm_typ)
97+
if pass_by_ref(job)
98+
attr = TypeAttribute("byref", llvm_typ)
99+
else
100+
attr = TypeAttribute("byval", llvm_typ)
101+
end
98102
push!(parameter_attributes(entry, arg.idx), attr)
99103
end
100104
end

0 commit comments

Comments
 (0)