Skip to content

Commit 047e172

Browse files
gbaraldi and claude
authored and committed
Simplify: let AMDGPU backend handle byref natively
Remove manual rewrite_byref_addrspaces!. The AMDGPU backend's AMDGPULowerKernelArguments pass already knows how to handle ptr byref(T) on amdgpu_kernel functions — it rewrites the pointer to load from the kernarg segment (addrspace 4) automatically. The previous manual approaches (addrspacecast, load→alloca→store) conflicted with the backend's own lowering. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2cce5a4 commit 047e172

1 file changed

Lines changed: 3 additions & 107 deletions

File tree

src/gcn.jl

Lines changed: 3 additions & 107 deletions
Original file line number | Diff line number | Diff line change
@@ -42,111 +42,6 @@ isintrinsic(::CompilerJob{GCNCompilerTarget}, fn::String) = in(fn, gcn_intrinsic
4242

4343
pass_by_ref(@nospecialize(job::CompilerJob{GCNCompilerTarget})) = true
4444

45-
# AMDGPU constant/kernarg address space
46-
const GCN_ADDRSPACE_CONSTANT = 4
47-
48-
# Rewrite byref pointer parameters from addrspace 0 to addrspace 4 (kernarg),
49-
# loading the data into local allocas so the function body can use generic pointers.
50-
# SROA will decompose the allocas during the optimization pipeline that follows.
51-
function rewrite_byref_addrspaces!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
52-
mod::LLVM.Module, f::LLVM.Function)
53-
ft = function_type(f)
54-
55-
# find byref parameters and their types
56-
args = classify_arguments(job, ft)
57-
filter!(args) do arg
58-
arg.cc != GHOST
59-
end
60-
61-
byref = BitVector(undef, length(parameters(ft)))
62-
byref_types = Vector{Any}(undef, length(parameters(ft)))
63-
for i in 1:length(byref)
64-
byref[i] = false
65-
for attr in collect(parameter_attributes(f, i))
66-
if kind(attr) == kind(TypeAttribute("byref", LLVM.VoidType()))
67-
byref[i] = true
68-
end
69-
end
70-
end
71-
for arg in args
72-
if arg.idx !== nothing && byref[arg.idx]
73-
byref_types[arg.idx] = arg.typ
74-
end
75-
end
76-
any(byref) || return f
77-
78-
# build new function type with addrspace(4) pointers for byref params
79-
new_types = LLVMType[]
80-
for (i, param) in enumerate(parameters(ft))
81-
if byref[i]
82-
push!(new_types, LLVM.PointerType(GCN_ADDRSPACE_CONSTANT))
83-
else
84-
push!(new_types, param)
85-
end
86-
end
87-
new_ft = LLVM.FunctionType(return_type(ft), new_types)
88-
new_f = LLVM.Function(mod, "", new_ft)
89-
linkage!(new_f, linkage(f))
90-
callconv!(new_f, callconv(f))
91-
for (arg, new_arg) in zip(parameters(f), parameters(new_f))
92-
LLVM.name!(new_arg, LLVM.name(arg))
93-
end
94-
95-
# copy parameter attributes, ensuring byref is preserved with correct type
96-
for (i, _) in enumerate(parameters(ft))
97-
for attr in collect(parameter_attributes(f, i))
98-
push!(parameter_attributes(new_f, i), attr)
99-
end
100-
# explicitly re-add byref with the correct type, in case the copy
101-
# dropped it due to the parameter type change
102-
if byref[i]
103-
llvm_typ = convert(LLVMType, byref_types[i])
104-
push!(parameter_attributes(new_f, i), TypeAttribute("byref", llvm_typ))
105-
end
106-
end
107-
108-
# load byref arguments from addrspace(4) into local allocas
109-
new_args = LLVM.Value[]
110-
@dispose builder=IRBuilder() begin
111-
entry = BasicBlock(new_f, "conversion")
112-
position!(builder, entry)
113-
114-
for (i, param) in enumerate(parameters(ft))
115-
if byref[i]
116-
# load the value from the kernarg pointer and store into a stack slot,
117-
# so the function body can keep using addrspace(0) pointers.
118-
# SROA will decompose this during optimization.
119-
llvm_typ = convert(LLVMType, byref_types[i])
120-
val = load!(builder, llvm_typ, parameters(new_f)[i])
121-
ptr = alloca!(builder, llvm_typ)
122-
store!(builder, val, ptr)
123-
push!(new_args, ptr)
124-
else
125-
push!(new_args, parameters(new_f)[i])
126-
end
127-
end
128-
129-
value_map = Dict{LLVM.Value, LLVM.Value}(
130-
param => new_args[i] for (i, param) in enumerate(parameters(f))
131-
)
132-
value_map[f] = new_f
133-
clone_into!(new_f, f; value_map,
134-
changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
135-
136-
br!(builder, blocks(new_f)[2])
137-
end
138-
139-
# replace old function
140-
fn = LLVM.name(f)
141-
prune_constexpr_uses!(f)
142-
@assert isempty(uses(f))
143-
replace_metadata_uses!(f, new_f)
144-
erase!(f)
145-
LLVM.name!(new_f, fn)
146-
147-
return new_f
148-
end
149-
15045
function finish_module!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
15146
mod::LLVM.Module, entry::LLVM.Function)
15247
lower_throw_extra!(mod)
@@ -155,8 +50,9 @@ function finish_module!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
15550
# calling convention
15651
callconv!(entry, LLVM.API.LLVMAMDGPUKERNELCallConv)
15752

158-
# rewrite byref parameters to use the kernarg address space
159-
entry = rewrite_byref_addrspaces!(job, mod, entry)
53+
# with byref, the AMDGPU backend's AMDGPULowerKernelArguments pass
54+
# will handle loading from the kernarg segment directly.
55+
# no need for lower_byval or manual rewriting.
16056
end
16157

16258
return entry

0 commit comments

Comments (0)