Skip to content

Commit f986368

Browse files
gbaraldi authored and claude committed
Address review feedback: fix comments, add tests, Runic formatting
- Fix addrspace(4) comment: it's the "constant" address space, not "kernarg" - Rewrite doc comment to accurately describe the Julia → AS(11) → AS(0) → AS(4) flow - Add InferAddressSpaces + SROA + InstCombine + EarlyCSE after kernarg rewrite - Add tests: byref attribute preservation, mixed byref/scalar params, programmatic IR inspection via compile(:llvm), zero scratch spills - Apply Runic formatting to new code Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3caabbf commit f986368

2 files changed

Lines changed: 136 additions & 17 deletions

File tree

src/gcn.jl

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,17 @@ function finish_module!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
5454
return entry
5555
end
5656

57-
function finish_ir!(@nospecialize(job::CompilerJob{GCNCompilerTarget}), mod::LLVM.Module,
58-
entry::LLVM.Function)
57+
function finish_ir!(
58+
@nospecialize(job::CompilerJob{GCNCompilerTarget}), mod::LLVM.Module,
59+
entry::LLVM.Function
60+
)
5961
if job.config.kernel
6062
entry = add_kernarg_address_spaces!(job, mod, entry)
6163

6264
# optimize after address space rewriting: propagate addrspace(4) through
6365
# the addrspacecast chains, then clean up newly-exposed opportunities
6466
tm = llvm_machine(job.config.target)
65-
@dispose pb=NewPMPassBuilder() tm begin
67+
@dispose pb = NewPMPassBuilder() tm begin
6668
add!(pb, NewPMFunctionPassManager()) do fpm
6769
add!(fpm, InferAddressSpacesPass())
6870
add!(fpm, SROAPass())
@@ -76,23 +78,26 @@ function finish_ir!(@nospecialize(job::CompilerJob{GCNCompilerTarget}), mod::LLV
7678
return entry
7779
end
7880

79-
# Rewrite byref kernel parameters from flat (addrspace 0) to kernarg (addrspace 4).
81+
# Rewrite byref kernel parameters from flat (addrspace 0) to constant (addrspace 4).
8082
#
81-
# On AMDGPU, the kernarg segment is in address space 4 and is scalar-loadable via s_load.
82-
# Clang emits byref parameters as `ptr addrspace(4)` from the frontend, but Julia's
83-
# RemoveJuliaAddrspacesPass strips all address spaces to flat. This pass restores the
84-
# correct address space so that struct field loads from byref arguments become s_load
85-
# instead of flat_load.
83+
# On AMDGPU, kernel arguments reside in the constant address space (addrspace 4),
84+
# which is scalar-loadable via s_load. Julia initially emits byref parameters as
85+
# pointers in addrspace(11) (tracked/derived), but RemoveJuliaAddrspacesPass strips
86+
# all non-integral address spaces to flat (addrspace 0) during optimization. This pass
87+
# restores addrspace(4) on byref parameters so that the backend can emit s_load
88+
# instead of flat_load for struct field accesses.
8689
#
8790
# NOTE: must run after optimization, where RemoveJuliaAddrspacesPass has already
8891
# converted Julia's addrspace(11) to flat (addrspace 0) on these parameters.
89-
function add_kernarg_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
90-
f::LLVM.Function)
92+
function add_kernarg_address_spaces!(
93+
@nospecialize(job::CompilerJob), mod::LLVM.Module,
94+
f::LLVM.Function
95+
)
9196
ft = function_type(f)
9297

9398
# find the byref parameters
9499
byref_mask = BitVector(undef, length(parameters(ft)))
95-
args = classify_arguments(job, ft; post_optimization=job.config.optimize)
100+
args = classify_arguments(job, ft; post_optimization = job.config.optimize)
96101
filter!(args) do arg
97102
arg.cc != GHOST
98103
end
@@ -114,7 +119,7 @@ function add_kernarg_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.
114119
new_types = LLVMType[]
115120
for (i, param) in enumerate(parameters(ft))
116121
if byref_mask[i] && param isa LLVM.PointerType && addrspace(param) == 0
117-
push!(new_types, LLVM.PointerType(#=kernarg=# 4))
122+
push!(new_types, LLVM.PointerType(#=constant=# 4))
118123
else
119124
push!(new_types, param)
120125
end
@@ -130,7 +135,7 @@ function add_kernarg_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.
130135
# (which expects flat pointers) continues to work. The AMDGPU backend's
131136
# AMDGPULowerKernelArguments traces these casts and produces s_load.
132137
new_args = LLVM.Value[]
133-
@dispose builder=IRBuilder() begin
138+
@dispose builder = IRBuilder() begin
134139
entry_bb = BasicBlock(new_f, "conversion")
135140
position!(builder, entry_bb)
136141

@@ -148,8 +153,10 @@ function add_kernarg_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.
148153
param => new_args[i] for (i, param) in enumerate(parameters(f))
149154
)
150155
value_map[f] = new_f
151-
clone_into!(new_f, f; value_map,
152-
changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
156+
clone_into!(
157+
new_f, f; value_map,
158+
changes = LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges
159+
)
153160

154161
# fall through from conversion block to cloned entry
155162
br!(builder, blocks(new_f)[2])
@@ -174,7 +181,7 @@ function add_kernarg_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.
174181
LLVM.name!(new_f, fn)
175182

176183
# clean up the extra conversion block
177-
@dispose pb=NewPMPassBuilder() begin
184+
@dispose pb = NewPMPassBuilder() begin
178185
add!(pb, NewPMFunctionPassManager()) do fpm
179186
add!(fpm, SimplifyCFGPass())
180187
end

test/gcn.jl

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,96 @@ end
6363
end
6464
end
6565

66+
@testset "byref attribute preserved on kernarg parameters" begin
67+
mod = @eval module $(gensym())
68+
struct LargeStruct
69+
a::Float64
70+
b::Float64
71+
c::Float64
72+
d::Float64
73+
end
74+
75+
function kernel(s::LargeStruct, out::Ptr{Float64})
76+
unsafe_store!(out, s.a + s.b + s.c + s.d)
77+
return
78+
end
79+
end
80+
81+
# the byref attribute must survive the addrspace rewrite (clone_into! can drop it)
82+
@test @filecheck begin
83+
check"CHECK: byref"
84+
check"CHECK: addrspace(4)"
85+
GCN.code_llvm(mod.kernel, Tuple{mod.LargeStruct, Ptr{Float64}};
86+
dump_module=true, kernel=true)
87+
end
88+
end
89+
90+
@testset "mixed byref and scalar kernel parameters" begin
91+
mod = @eval module $(gensym())
92+
struct Params
93+
x::Float64
94+
y::Float64
95+
end
96+
97+
function kernel(a::Float64, s::Params, out::Ptr{Float64})
98+
unsafe_store!(out, a + s.x + s.y)
99+
return
100+
end
101+
end
102+
103+
# scalar Float64 and Ptr should NOT be in addrspace(4),
104+
# only the struct byref param should be
105+
@test @filecheck begin
106+
check"CHECK: define amdgpu_kernel void"
107+
check"CHECK-SAME: double"
108+
check"CHECK-SAME: ptr addrspace(4)"
109+
check"CHECK-SAME: ptr"
110+
GCN.code_llvm(mod.kernel, Tuple{Float64, mod.Params, Ptr{Float64}};
111+
dump_module=true, kernel=true)
112+
end
113+
end
114+
115+
@testset "add_kernarg_address_spaces! rewrites IR correctly" begin
116+
mod = @eval module $(gensym())
117+
struct KernelArgs
118+
x::Float64
119+
y::Float64
120+
z::Float64
121+
end
122+
123+
function kernel(s::KernelArgs, scale::Float64, out::Ptr{Float64})
124+
unsafe_store!(out, (s.x + s.y + s.z) * scale)
125+
return
126+
end
127+
end
128+
129+
job, _ = GCN.create_job(mod.kernel, Tuple{mod.KernelArgs, Float64, Ptr{Float64}};
130+
kernel=true)
131+
JuliaContext() do ctx
132+
ir, meta = GPUCompiler.compile(:llvm, job)
133+
134+
entry = meta.entry
135+
ft = function_type(entry)
136+
params = parameters(ft)
137+
138+
# the struct byref param should be ptr addrspace(4)
139+
has_as4 = any(p -> p isa LLVM.PointerType && addrspace(p) == 4, params)
140+
@test has_as4
141+
142+
# non-struct params (double, ptr) should NOT be in addrspace(4)
143+
non_as4_ptrs = filter(params) do p
144+
p isa LLVM.PointerType && addrspace(p) != 4
145+
end
146+
@test !isempty(non_as4_ptrs) # the Ptr{Float64} out param
147+
148+
# byref attribute must be present
149+
ir_str = string(ir)
150+
@test occursin("byref", ir_str)
151+
152+
dispose(ir)
153+
end
154+
end
155+
66156
@testset "https://github.com/JuliaGPU/AMDGPU.jl/issues/846" begin
67157
ir, rt = GCN.code_typed((Tuple{Tuple{Val{4}}, Tuple{Float32}},); always_inline=true) do t
68158
t[1]
@@ -88,13 +178,35 @@ end
88178
end
89179
end
90180

181+
# struct field loads from kernarg should use s_load, not flat_load
91182
@test @filecheck begin
92183
check"CHECK: s_load_dwordx"
93184
check"CHECK-NOT: flat_load"
94185
GCN.code_native(mod.kernel, Tuple{mod.MyStruct, Ptr{Float64}}; kernel=true)
95186
end
96187
end
97188

189+
@testset "no scratch spills for small struct kernarg" begin
190+
mod = @eval module $(gensym())
191+
struct SmallStruct
192+
x::Float64
193+
y::Float64
194+
end
195+
196+
function kernel(s::SmallStruct, out::Ptr{Float64})
197+
unsafe_store!(out, s.x + s.y)
198+
return
199+
end
200+
end
201+
202+
# a small struct kernel should not need scratch memory
203+
@test @filecheck begin
204+
check"CHECK: .private_segment_fixed_size: 0"
205+
GCN.code_native(mod.kernel, Tuple{mod.SmallStruct, Ptr{Float64}};
206+
dump_module=true, kernel=true)
207+
end
208+
end
209+
98210
@testset "skip scalar trap" begin
99211
mod = @eval module $(gensym())
100212
workitem_idx_x() = ccall("llvm.amdgcn.workitem.id.x", llvmcall, Int32, ())

0 commit comments

Comments (0)