@@ -49,15 +49,126 @@ function finish_module!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
4949 if job. config. kernel
5050 # calling convention
5151 callconv! (entry, LLVM. API. LLVMAMDGPUKERNELCallConv)
52-
53- # with byref, the AMDGPU backend's AMDGPULowerKernelArguments pass
54- # will handle loading from the kernarg segment directly.
55- # no need for lower_byval or manual rewriting.
5652 end
5753
5854 return entry
5955end
6056
# Target-specific IR finalization for GCN: kernel entry points are handed to
# `add_kernarg_address_spaces!` (which may return a replacement function);
# anything that is not a kernel is returned untouched.
function finish_ir!(@nospecialize(job::CompilerJob{GCNCompilerTarget}), mod::LLVM.Module,
                    entry::LLVM.Function)
    # only kernel entry points are rewritten
    job.config.kernel || return entry
    return add_kernarg_address_spaces!(job, mod, entry)
end
64+
# Rewrite byref kernel parameters from flat (addrspace 0) to kernarg (addrspace 4).
#
# On AMDGPU, the kernarg segment is in address space 4 and is scalar-loadable via s_load.
# Clang emits byref parameters as `ptr addrspace(4)` from the frontend, but Julia's
# RemoveJuliaAddrspacesPass strips all address spaces to flat. This pass restores the
# correct address space so that struct field loads from byref arguments become s_load
# instead of flat_load.
#
# NOTE: must run after optimization, where RemoveJuliaAddrspacesPass has already
#       converted Julia's addrspace(11) to flat (addrspace 0) on these parameters.
#
# Returns `f` itself when no byref parameter is a flat pointer. Otherwise `f` is erased
# from `mod` and the returned function is its replacement, carrying the same name.
function add_kernarg_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                                     f::LLVM.Function)
    ft = function_type(f)
    param_types = parameters(ft)

    # find the byref parameters. the mask starts out all-false so that a parameter slot
    # which `classify_arguments` does not report is never treated as byref (an `undef`
    # BitVector would leave such slots with arbitrary contents).
    byref_mask = falses(length(param_types))
    args = classify_arguments(job, ft; post_optimization=job.config.optimize)
    filter!(args) do arg
        arg.cc != GHOST
    end
    for arg in args
        byref_mask[arg.idx] = (arg.cc == BITS_REF || arg.cc == KERNEL_STATE)
    end

    # decide once which parameters are flat-pointer byref params that need rewriting
    rewrite_mask = Bool[
        byref_mask[i] && param isa LLVM.PointerType && addrspace(param) == 0
        for (i, param) in enumerate(param_types)
    ]
    any(rewrite_mask) || return f

    # generate the new function type with kernarg address space on byref params
    new_types = LLVMType[
        rewrite_mask[i] ? LLVM.PointerType(#=kernarg=# 4) : param
        for (i, param) in enumerate(param_types)
    ]
    new_ft = LLVM.FunctionType(return_type(ft), new_types)
    # placeholder name; the real name is transferred once the old function is erased
    new_f = LLVM.Function(mod, " ", new_ft)
    linkage!(new_f, linkage(f))
    for (arg, new_arg) in zip(parameters(f), parameters(new_f))
        LLVM.name!(new_arg, LLVM.name(arg))
    end

    # insert addrspacecasts from kernarg (4) back to flat (0) so that the cloned IR
    # (which expects flat pointers) continues to work. InferAddressSpaces will then
    # propagate addrspace(4) through GEPs and loads, eliminating the casts.
    new_args = LLVM.Value[]
    @dispose builder=IRBuilder() begin
        entry_bb = BasicBlock(new_f, " conversion")
        position!(builder, entry_bb)

        for (i, new_param) in enumerate(parameters(new_f))
            if rewrite_mask[i]
                cast = addrspacecast!(builder, new_param, LLVM.PointerType(#=flat=# 0))
                push!(new_args, cast)
            else
                push!(new_args, new_param)
            end
            # carry over the parameter attributes of the original function
            for attr in collect(parameter_attributes(f, i))
                push!(parameter_attributes(new_f, i), attr)
            end
        end

        # clone the original function body, substituting the (possibly cast) new
        # arguments for the old parameters and `new_f` for self-references to `f`
        value_map = Dict{LLVM.Value, LLVM.Value}(
            param => new_args[i] for (i, param) in enumerate(parameters(f))
        )
        value_map[f] = new_f
        clone_into!(new_f, f; value_map,
                    changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)

        # fall through from conversion block to cloned entry
        # (block 1 is the conversion block created above, block 2 the cloned entry)
        br!(builder, blocks(new_f)[2])
    end

    # replace the old function
    fn = LLVM.name(f)
    prune_constexpr_uses!(f)
    @assert isempty(uses(f))
    replace_metadata_uses!(f, new_f)
    erase!(f)
    LLVM.name!(new_f, fn)

    # propagate addrspace(4) through GEPs and loads, then clean up
    @dispose pb=NewPMPassBuilder() begin
        add!(pb, NewPMFunctionPassManager()) do fpm
            add!(fpm, InferAddressSpacesPass())
        end
        add!(pb, NewPMFunctionPassManager()) do fpm
            add!(fpm, SimplifyCFGPass())
            add!(fpm, SROAPass())
            add!(fpm, EarlyCSEPass())
            add!(fpm, InstCombinePass())
        end
        run!(pb, mod)
    end

    # look the function up again by name after the passes have run over the module
    return functions(mod)[fn]
end
171+
61172
62173# # LLVM passes
63174
0 commit comments