Skip to content

Commit 5e07342

Browse files
gbaraldi and claude committed
Fix byref rewrite: load from addrspace(4) into alloca instead of addrspacecast
The addrspacecast from addrspace(4) to addrspace(0) caused "illegal VGPR to SGPR copy" errors because LLVM couldn't properly lower generic pointer accesses back to the constant address space. Instead, follow Metal's approach: load the struct from the addrspace(4) kernarg pointer into a local alloca (addrspace 5), and let SROA decompose it during the optimization pipeline. This avoids the address space mismatch while still benefiting from byref semantics — the load from addrspace(4) is a scalar load from the kernarg segment, and SROA will eliminate dead fields.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d0fbeb8 commit 5e07342

1 file changed

Lines changed: 22 additions & 5 deletions

File tree

src/gcn.jl

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,20 @@ pass_by_ref(@nospecialize(job::CompilerJob{GCNCompilerTarget})) = true
4646
const GCN_ADDRSPACE_CONSTANT = 4
4747

4848
# Rewrite byref pointer parameters from addrspace 0 to addrspace 4 (kernarg),
49-
# inserting addrspacecasts so the function body can continue using generic pointers.
49+
# loading the data into local allocas so the function body can use generic pointers.
50+
# SROA will decompose the allocas during the optimization pipeline that follows.
5051
function rewrite_byref_addrspaces!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
5152
mod::LLVM.Module, f::LLVM.Function)
5253
ft = function_type(f)
5354

54-
# find byref parameters
55+
# find byref parameters and their types
56+
args = classify_arguments(job, ft)
57+
filter!(args) do arg
58+
arg.cc != GHOST
59+
end
60+
5561
byref = BitVector(undef, length(parameters(ft)))
62+
byref_types = Vector{Any}(undef, length(parameters(ft)))
5663
for i in 1:length(byref)
5764
byref[i] = false
5865
for attr in collect(parameter_attributes(f, i))
@@ -61,6 +68,11 @@ function rewrite_byref_addrspaces!(@nospecialize(job::CompilerJob{GCNCompilerTar
6168
end
6269
end
6370
end
71+
for arg in args
72+
if arg.idx !== nothing && byref[arg.idx]
73+
byref_types[arg.idx] = arg.typ
74+
end
75+
end
6476
any(byref) || return f
6577

6678
# build new function type with addrspace(4) pointers for byref params
@@ -87,16 +99,21 @@ function rewrite_byref_addrspaces!(@nospecialize(job::CompilerJob{GCNCompilerTar
8799
end
88100
end
89101

90-
# insert addrspacecasts in entry block
102+
# load byref arguments from addrspace(4) into local allocas
91103
new_args = LLVM.Value[]
92104
@dispose builder=IRBuilder() begin
93105
entry = BasicBlock(new_f, "conversion")
94106
position!(builder, entry)
95107

96108
for (i, param) in enumerate(parameters(ft))
97109
if byref[i]
98-
# cast from addrspace(4) to addrspace(0) for the function body
99-
ptr = addrspacecast!(builder, parameters(new_f)[i], param)
110+
# load the value from the kernarg pointer and store into a stack slot,
111+
# so the function body can keep using addrspace(0) pointers.
112+
# SROA will decompose this during optimization.
113+
llvm_typ = convert(LLVMType, byref_types[i])
114+
val = load!(builder, llvm_typ, parameters(new_f)[i])
115+
ptr = alloca!(builder, llvm_typ)
116+
store!(builder, val, ptr)
100117
push!(new_args, ptr)
101118
else
102119
push!(new_args, parameters(new_f)[i])

0 commit comments

Comments
 (0)