Skip to content

Commit ee4a85c

Browse files
authored
Merge pull request #772 from gbaraldi/gcn-byref-kernel-args
GCN: use byref instead of byval+lower_byval for kernel arguments
2 parents 1df660b + 3df173f commit ee4a85c

4 files changed

Lines changed: 316 additions & 6 deletions

File tree

src/gcn.jl

Lines changed: 143 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,161 @@ runtime_slug(job::CompilerJob{GCNCompilerTarget}) = "gcn-$(job.config.target.dev
4040
const gcn_intrinsics = () # TODO: ("vprintf", "__assertfail", "malloc", "free")
4141
isintrinsic(::CompilerJob{GCNCompilerTarget}, fn::String) = in(fn, gcn_intrinsics)
4242

43+
pass_by_ref(@nospecialize(job::CompilerJob{GCNCompilerTarget})) = true
44+
4345
function finish_module!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
4446
mod::LLVM.Module, entry::LLVM.Function)
4547
lower_throw_extra!(mod)
4648

4749
if job.config.kernel
4850
# calling convention
4951
callconv!(entry, LLVM.API.LLVMAMDGPUKERNELCallConv)
50-
51-
# work around bad byval codegen (JuliaGPU/GPUCompiler.jl#92)
52-
entry = lower_byval(job, mod, entry)
5352
end
5453

5554
return entry
5655
end
5756

57+
function finish_ir!(
58+
@nospecialize(job::CompilerJob{GCNCompilerTarget}), mod::LLVM.Module,
59+
entry::LLVM.Function
60+
)
61+
if job.config.kernel
62+
entry = add_kernarg_address_spaces!(job, mod, entry)
63+
64+
# optimize after address space rewriting: propagate addrspace(4) through
65+
# the addrspacecast chains, then clean up newly-exposed opportunities
66+
tm = llvm_machine(job.config.target)
67+
@dispose pb=NewPMPassBuilder() begin
68+
add!(pb, NewPMFunctionPassManager()) do fpm
69+
add!(fpm, InferAddressSpacesPass())
70+
add!(fpm, SROAPass())
71+
add!(fpm, InstCombinePass())
72+
add!(fpm, EarlyCSEPass())
73+
add!(fpm, SimplifyCFGPass())
74+
end
75+
run!(pb, mod, tm)
76+
end
77+
end
78+
return entry
79+
end
80+
81+
# Rewrite byref kernel parameters from flat (addrspace 0) to constant (addrspace 4).
82+
#
83+
# On AMDGPU, kernel arguments reside in the constant address space (addrspace 4),
84+
# which is scalar-loadable via s_load. Julia initially emits byref parameters as
85+
# pointers in addrspace(11) (tracked/derived), but RemoveJuliaAddrspacesPass strips
86+
# all non-integral address spaces to flat (addrspace 0) during optimization. This pass
87+
# restores addrspace(4) on byref parameters so that the backend can emit s_load
88+
# instead of flat_load for struct field accesses.
89+
#
90+
# NOTE: must run after optimization, where RemoveJuliaAddrspacesPass has already
91+
# converted Julia's addrspace(11) to flat (addrspace 0) on these parameters.
92+
function add_kernarg_address_spaces!(
93+
@nospecialize(job::CompilerJob), mod::LLVM.Module,
94+
f::LLVM.Function
95+
)
96+
ft = function_type(f)
97+
98+
# find the byref parameters by checking for the byref attribute directly,
99+
# rather than re-classifying arguments (which can fail on typed-pointer LLVM
100+
# due to element type mismatches in classify_arguments assertions).
101+
byref_kind = LLVM.API.LLVMGetEnumAttributeKindForName("byref", 5)
102+
byref_mask = BitVector(undef, length(parameters(ft)))
103+
for i in 1:length(parameters(ft))
104+
attrs = collect(parameter_attributes(f, i))
105+
byref_mask[i] = any(a -> a isa TypeAttribute && kind(a) == byref_kind, attrs)
106+
end
107+
108+
# check if any flat pointer byref params need rewriting
109+
needs_rewrite = false
110+
for (i, param) in enumerate(parameters(ft))
111+
if byref_mask[i] && param isa LLVM.PointerType && addrspace(param) == 0
112+
needs_rewrite = true
113+
break
114+
end
115+
end
116+
needs_rewrite || return f
117+
118+
# generate the new function type with constant address space on byref params
119+
new_types = LLVMType[]
120+
for (i, param) in enumerate(parameters(ft))
121+
if byref_mask[i] && param isa LLVM.PointerType && addrspace(param) == 0
122+
if supports_typed_pointers(context())
123+
push!(new_types, LLVM.PointerType(eltype(param), #=constant=# 4))
124+
else
125+
push!(new_types, LLVM.PointerType(#=constant=# 4))
126+
end
127+
else
128+
push!(new_types, param)
129+
end
130+
end
131+
new_ft = LLVM.FunctionType(return_type(ft), new_types)
132+
new_f = LLVM.Function(mod, "", new_ft)
133+
linkage!(new_f, linkage(f))
134+
for (arg, new_arg) in zip(parameters(f), parameters(new_f))
135+
LLVM.name!(new_arg, LLVM.name(arg))
136+
end
137+
138+
# insert addrspacecasts from kernarg (4) back to flat (0) so that the cloned IR
139+
# (which expects flat pointers) continues to work. The AMDGPU backend's
140+
# AMDGPULowerKernelArguments traces these casts and produces s_load.
141+
new_args = LLVM.Value[]
142+
@dispose builder=IRBuilder() begin
143+
entry_bb = BasicBlock(new_f, "conversion")
144+
position!(builder, entry_bb)
145+
146+
for (i, param) in enumerate(parameters(ft))
147+
if byref_mask[i] && param isa LLVM.PointerType && addrspace(param) == 0
148+
cast = addrspacecast!(builder, parameters(new_f)[i], param)
149+
push!(new_args, cast)
150+
else
151+
push!(new_args, parameters(new_f)[i])
152+
end
153+
end
154+
155+
# clone the original function body
156+
value_map = Dict{LLVM.Value, LLVM.Value}(
157+
param => new_args[i] for (i, param) in enumerate(parameters(f))
158+
)
159+
value_map[f] = new_f
160+
clone_into!(
161+
new_f, f; value_map,
162+
changes = LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges
163+
)
164+
165+
# fall through from conversion block to cloned entry
166+
br!(builder, blocks(new_f)[2])
167+
end
168+
169+
# copy parameter attributes AFTER clone_into!, because CloneFunctionInto
170+
# overwrites all attributes via setAttributes. For byref params, the VMap
171+
# maps old args to addrspacecast instructions (not Arguments), so LLVM's
172+
# attribute remapping silently drops them. We must re-add them here.
173+
for i in 1:length(parameters(ft))
174+
for attr in collect(parameter_attributes(f, i))
175+
push!(parameter_attributes(new_f, i), attr)
176+
end
177+
end
178+
179+
# replace the old function
180+
fn = LLVM.name(f)
181+
prune_constexpr_uses!(f)
182+
@assert isempty(uses(f))
183+
replace_metadata_uses!(f, new_f)
184+
erase!(f)
185+
LLVM.name!(new_f, fn)
186+
187+
# clean up the extra conversion block
188+
@dispose pb=NewPMPassBuilder() begin
189+
add!(pb, NewPMFunctionPassManager()) do fpm
190+
add!(fpm, SimplifyCFGPass())
191+
end
192+
run!(pb, mod)
193+
end
194+
195+
return functions(mod)[fn]
196+
end
197+
58198

59199
## LLVM passes
60200

src/interface.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,12 @@ kernel_state_type(@nospecialize(job::CompilerJob)) = Nothing
272272
# Does the target need to pass kernel arguments by value?
273273
pass_by_value(@nospecialize(job::CompilerJob)) = true
274274

275+
# Should the target use byref instead of byval+lower_byval for kernel arguments?
276+
# When true, aggregate arguments are passed as pointers with the byref attribute,
277+
# allowing the backend to load fields directly from the argument memory (e.g. kernarg
278+
# segment on AMDGPU) instead of materializing the entire struct via first-class aggregates.
279+
pass_by_ref(@nospecialize(job::CompilerJob)) = false
280+
275281
# whether pointer is a valid call target
276282
valid_function_pointer(@nospecialize(job::CompilerJob), ptr::Ptr{Cvoid}) = false
277283

src/irgen.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,11 @@ function irgen(@nospecialize(job::CompilerJob))
9494
for arg in args
9595
if arg.cc == BITS_REF
9696
llvm_typ = convert(LLVMType, arg.typ)
97-
attr = TypeAttribute("byval", llvm_typ)
97+
if pass_by_ref(job)
98+
attr = TypeAttribute("byref", llvm_typ)
99+
else
100+
attr = TypeAttribute("byval", llvm_typ)
101+
end
98102
push!(parameter_attributes(entry, arg.idx), attr)
99103
end
100104
end

test/gcn.jl

Lines changed: 162 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,124 @@ end
3737
end
3838
end
3939

40+
@testset "kernarg address space for byref parameters" begin
41+
mod = @eval module $(gensym())
42+
struct MyStruct
43+
x::Float64
44+
y::Float64
45+
end
46+
47+
function kernel(s::MyStruct)
48+
s.x + s.y
49+
return
50+
end
51+
end
52+
53+
# byref struct params should be ptr addrspace(4) in kernel IR
54+
@test @filecheck begin
55+
check"TYPED: define amdgpu_kernel void @_Z6kernel8MyStruct({{.*}} addrspace(4)*"
56+
check"OPAQUE: define amdgpu_kernel void @_Z6kernel8MyStruct(ptr addrspace(4)"
57+
GCN.code_llvm(mod.kernel, Tuple{mod.MyStruct}; dump_module=true, kernel=true)
58+
end
59+
60+
# non-kernel should NOT have addrspace(4)
61+
@test @filecheck begin
62+
check"CHECK-NOT: addrspace(4)"
63+
GCN.code_llvm(mod.kernel, Tuple{mod.MyStruct}; dump_module=true, kernel=false)
64+
end
65+
end
66+
67+
@testset "byref attribute preserved on kernarg parameters" begin
68+
mod = @eval module $(gensym())
69+
struct LargeStruct
70+
a::Float64
71+
b::Float64
72+
c::Float64
73+
d::Float64
74+
end
75+
76+
function kernel(s::LargeStruct, out::Ptr{Float64})
77+
unsafe_store!(out, s.a + s.b + s.c + s.d)
78+
return
79+
end
80+
end
81+
82+
# the byref attribute must survive the addrspace rewrite (clone_into! can drop it)
83+
@test @filecheck begin
84+
check"CHECK: byref"
85+
check"CHECK: addrspace(4)"
86+
GCN.code_llvm(mod.kernel, Tuple{mod.LargeStruct, Ptr{Float64}};
87+
dump_module=true, kernel=true)
88+
end
89+
end
90+
91+
@testset "mixed byref and scalar kernel parameters" begin
92+
mod = @eval module $(gensym())
93+
struct Params
94+
x::Float64
95+
y::Float64
96+
end
97+
98+
function kernel(a::Float64, s::Params, out::Ptr{Float64})
99+
unsafe_store!(out, a + s.x + s.y)
100+
return
101+
end
102+
end
103+
104+
# scalar Float64 should NOT be in addrspace(4),
105+
# only the struct byref param should be.
106+
# NOTE: Ptr{Float64} is lowered to i64 on Julia ≤1.11 and ptr on Julia 1.12+.
107+
@test @filecheck begin
108+
check"CHECK: define amdgpu_kernel void"
109+
check"CHECK-SAME: double"
110+
check"TYPED-SAME: {{.*}} addrspace(4)*"
111+
check"OPAQUE-SAME: ptr addrspace(4)"
112+
check"CHECK-SAME: {{(i64|ptr)}}"
113+
GCN.code_llvm(mod.kernel, Tuple{Float64, mod.Params, Ptr{Float64}};
114+
dump_module=true, kernel=true)
115+
end
116+
end
117+
118+
@testset "add_kernarg_address_spaces! rewrites IR correctly" begin
119+
mod = @eval module $(gensym())
120+
struct KernelArgs
121+
x::Float64
122+
y::Float64
123+
z::Float64
124+
end
125+
126+
function kernel(s::KernelArgs, scale::Float64, out::Ptr{Float64})
127+
unsafe_store!(out, (s.x + s.y + s.z) * scale)
128+
return
129+
end
130+
end
131+
132+
job, _ = GCN.create_job(mod.kernel, Tuple{mod.KernelArgs, Float64, Ptr{Float64}};
133+
kernel=true)
134+
JuliaContext() do ctx
135+
ir, meta = GPUCompiler.compile(:llvm, job)
136+
137+
entry = meta.entry
138+
ft = function_type(entry)
139+
params = parameters(ft)
140+
141+
# the struct byref param should be ptr addrspace(4)
142+
has_as4 = any(p -> p isa LLVM.PointerType && addrspace(p) == 4, params)
143+
@test has_as4
144+
145+
# non-struct params (double, and i64/ptr for Ptr{Float64}) should NOT
146+
# be in addrspace(4). Ptr{Float64} is i64 on Julia ≤1.11, ptr on 1.12+.
147+
non_byref = filter(p -> !(p isa LLVM.PointerType && addrspace(p) == 4), params)
148+
@test !isempty(non_byref) # double (and i64 or ptr) params
149+
150+
# byref attribute must be present
151+
ir_str = string(ir)
152+
@test occursin("byref", ir_str)
153+
154+
dispose(ir)
155+
end
156+
end
157+
40158
@testset "https://github.com/JuliaGPU/AMDGPU.jl/issues/846" begin
41159
ir, rt = GCN.code_typed((Tuple{Tuple{Val{4}}, Tuple{Float32}},); always_inline=true) do t
42160
t[1]
@@ -49,6 +167,48 @@ end
49167
############################################################################################
50168
@testset "assembly" begin
51169

170+
@testset "s_load for kernarg struct access" begin
171+
mod = @eval module $(gensym())
172+
struct MyStruct
173+
x::Float64
174+
y::Float64
175+
end
176+
177+
function kernel(s::MyStruct, out::Ptr{Float64})
178+
unsafe_store!(out, s.x + s.y)
179+
return
180+
end
181+
end
182+
183+
# struct field loads from kernarg should use s_load, not flat_load
184+
@test @filecheck begin
185+
check"CHECK: s_load_dwordx"
186+
check"CHECK-NOT: flat_load"
187+
GCN.code_native(mod.kernel, Tuple{mod.MyStruct, Ptr{Float64}}; kernel=true)
188+
end
189+
end
190+
191+
@testset "no scratch spills for small struct kernarg" begin
192+
mod = @eval module $(gensym())
193+
struct SmallStruct
194+
x::Float64
195+
y::Float64
196+
end
197+
198+
function kernel(s::SmallStruct, out::Ptr{Float64})
199+
unsafe_store!(out, s.x + s.y)
200+
return
201+
end
202+
end
203+
204+
# a small struct kernel should not need scratch memory
205+
@test @filecheck begin
206+
check"CHECK: .private_segment_fixed_size: 0"
207+
GCN.code_native(mod.kernel, Tuple{mod.SmallStruct, Ptr{Float64}};
208+
dump_module=true, kernel=true)
209+
end
210+
end
211+
52212
@testset "skip scalar trap" begin
53213
mod = @eval module $(gensym())
54214
workitem_idx_x() = ccall("llvm.amdgcn.workitem.id.x", llvmcall, Int32, ())
@@ -101,9 +261,9 @@ end
101261
end
102262

103263
@test @filecheck begin
104-
check"CHECK-NOT: .amdhsa_kernel {{(julia|j)_nonentry_[0-9]+}}"
105264
check"CHECK: .type {{(julia|j)_nonentry_[0-9]+}},@function"
106-
check"CHECK: .amdhsa_kernel _Z5entry5Int64"
265+
check"CHECK: .symbol:{{.*}}_Z5entry5Int64.kd"
266+
check"CHECK-NOT: .symbol:{{.*}}nonentry"
107267
GCN.code_native(mod.entry, Tuple{Int64}; dump_module=true, kernel=true)
108268
end
109269
end

0 commit comments

Comments (0)