Skip to content

Commit 1c20647

Browse files
gbaraldi and claude committed
GCN: rewrite byref kernel params to addrspace(4) for s_load
Add finish_ir! for GCN that rewrites byref kernel parameters from flat (addrspace 0) to kernarg (addrspace 4) after optimization. Clang emits byref params as ptr addrspace(4) from the frontend, but Julia's RemoveJuliaAddrspacesPass strips them to flat. This causes struct field loads to use flat_load instead of s_load. The pass creates a new function with ptr addrspace(4) parameters, inserts addrspacecasts back to flat for the cloned IR, then runs InferAddressSpaces to propagate addrspace(4) through all GEPs and loads. The result is that all kernel argument struct field accesses become s_load (scalar, cached, one per wavefront) instead of flat_load (per-lane, address disambiguation overhead). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 047e172 commit 1c20647

1 file changed

Lines changed: 115 additions & 4 deletions

File tree

src/gcn.jl

Lines changed: 115 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,126 @@ function finish_module!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
4949
if job.config.kernel
5050
# calling convention
5151
callconv!(entry, LLVM.API.LLVMAMDGPUKERNELCallConv)
52-
53-
# with byref, the AMDGPU backend's AMDGPULowerKernelArguments pass
54-
# will handle loading from the kernarg segment directly.
55-
# no need for lower_byval or manual rewriting.
5652
end
5753

5854
return entry
5955
end
6056

57+
function finish_ir!(@nospecialize(job::CompilerJob{GCNCompilerTarget}), mod::LLVM.Module,
                    entry::LLVM.Function)
    # Only kernel entry points get their byref parameters rewritten into the
    # kernarg address space; plain device functions are returned untouched.
    job.config.kernel || return entry
    return add_kernarg_address_spaces!(job, mod, entry)
end
64+
65+
# Rewrite byref kernel parameters from flat (addrspace 0) to kernarg (addrspace 4).
#
# On AMDGPU, the kernarg segment is in address space 4 and is scalar-loadable via s_load.
# Clang emits byref parameters as `ptr addrspace(4)` from the frontend, but Julia's
# RemoveJuliaAddrspacesPass strips all address spaces to flat. This pass restores the
# correct address space so that struct field loads from byref arguments become s_load
# instead of flat_load.
#
# NOTE: must run after optimization, where RemoveJuliaAddrspacesPass has already
# converted Julia's addrspace(11) to flat (addrspace 0) on these parameters.
function add_kernarg_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                                     f::LLVM.Function)
    ft = function_type(f)

    # find the byref parameters.
    # use `falses` rather than `BitVector(undef, ...)`: any parameter index that is
    # not covered by an entry in `args` below would otherwise read an undefined bit,
    # making the rewrite decision nondeterministic.
    byref_mask = falses(length(parameters(ft)))
    args = classify_arguments(job, ft; post_optimization=job.config.optimize)
    filter!(args) do arg
        arg.cc != GHOST
    end
    for arg in args
        byref_mask[arg.idx] = (arg.cc == BITS_REF || arg.cc == KERNEL_STATE)
    end

    # a parameter needs rewriting when it is byref and currently a flat pointer.
    # compute this once so the same predicate isn't repeated in every loop below.
    rewrite_mask = [byref_mask[i] && param isa LLVM.PointerType && addrspace(param) == 0
                    for (i, param) in enumerate(parameters(ft))]
    any(rewrite_mask) || return f

    # generate the new function type with kernarg address space on byref params
    new_types = LLVMType[]
    for (i, param) in enumerate(parameters(ft))
        if rewrite_mask[i]
            push!(new_types, LLVM.PointerType(#=kernarg=# 4))
        else
            push!(new_types, param)
        end
    end
    new_ft = LLVM.FunctionType(return_type(ft), new_types)
    new_f = LLVM.Function(mod, "", new_ft)
    linkage!(new_f, linkage(f))
    for (arg, new_arg) in zip(parameters(f), parameters(new_f))
        LLVM.name!(new_arg, LLVM.name(arg))
    end

    # insert addrspacecasts from kernarg (4) back to flat (0) so that the cloned IR
    # (which expects flat pointers) continues to work. InferAddressSpaces will then
    # propagate addrspace(4) through GEPs and loads, eliminating the casts.
    new_args = LLVM.Value[]
    @dispose builder=IRBuilder() begin
        entry_bb = BasicBlock(new_f, "conversion")
        position!(builder, entry_bb)

        for (i, param) in enumerate(parameters(ft))
            if rewrite_mask[i]
                cast = addrspacecast!(builder, parameters(new_f)[i], LLVM.PointerType(0))
                push!(new_args, cast)
            else
                push!(new_args, parameters(new_f)[i])
            end
            # carry over all parameter attributes to the new function
            for attr in collect(parameter_attributes(f, i))
                push!(parameter_attributes(new_f, i), attr)
            end
        end

        # clone the original function body
        value_map = Dict{LLVM.Value, LLVM.Value}(
            param => new_args[i] for (i, param) in enumerate(parameters(f))
        )
        value_map[f] = new_f
        clone_into!(new_f, f; value_map,
                    changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)

        # fall through from conversion block to cloned entry
        br!(builder, blocks(new_f)[2])
    end

    # replace the old function
    fn = LLVM.name(f)
    prune_constexpr_uses!(f)
    @assert isempty(uses(f))
    replace_metadata_uses!(f, new_f)
    erase!(f)
    LLVM.name!(new_f, fn)

    # propagate addrspace(4) through GEPs and loads, then clean up
    @dispose pb=NewPMPassBuilder() begin
        add!(pb, NewPMFunctionPassManager()) do fpm
            add!(fpm, InferAddressSpacesPass())
        end
        add!(pb, NewPMFunctionPassManager()) do fpm
            add!(fpm, SimplifyCFGPass())
            add!(fpm, SROAPass())
            add!(fpm, EarlyCSEPass())
            add!(fpm, InstCombinePass())
        end
        run!(pb, mod)
    end

    return functions(mod)[fn]
end
171+
61172

62173
## LLVM passes
63174

0 commit comments

Comments (0)