@@ -40,21 +40,161 @@ runtime_slug(job::CompilerJob{GCNCompilerTarget}) = "gcn-$(job.config.target.dev
# Names of the device intrinsics provided for GCN targets; the list is currently empty.
const gcn_intrinsics = () # TODO: ("vprintf", "__assertfail", "malloc", "free")

# Report whether `fn` names one of the entries of `gcn_intrinsics`.
isintrinsic(::CompilerJob{GCNCompilerTarget}, fn::String) = fn in gcn_intrinsics
4242
# GCN jobs unconditionally answer `true` here.
# NOTE(review): presumably this opts kernel arguments into being passed `byref` so that
# `add_kernarg_address_spaces!` can later rewrite them -- confirm against the interface.
function pass_by_ref(@nospecialize(job::CompilerJob{GCNCompilerTarget}))
    return true
end
44+
# Module finalization hook for GCN jobs: applies `lower_throw_extra!` to `mod` and, when the
# job is a kernel, tags `entry` with the AMDGPU kernel calling convention. Returns `entry`.
# NOTE(review): this span is a diff hunk; the lines marked `-` (the `lower_byval`
# workaround) are the ones being removed, after which `entry` is no longer reassigned here.
4345function finish_module! (@nospecialize (job:: CompilerJob{GCNCompilerTarget} ),
4446 mod:: LLVM.Module , entry:: LLVM.Function )
4547 lower_throw_extra! (mod)
4648
# only kernel entry points get the kernel calling convention
4749 if job. config. kernel
4850 # calling convention
4951 callconv! (entry, LLVM. API. LLVMAMDGPUKERNELCallConv)
50-
51- # work around bad byval codegen (JuliaGPU/GPUCompiler.jl#92)
52- entry = lower_byval (job, mod, entry)
5352 end
5453
5554 return entry
5655end
5756
# IR finalization hook for GCN jobs.
#
# For kernels, byref arguments are moved to the constant address space (see
# `add_kernarg_address_spaces!`), after which a short function pipeline is run over `mod`
# to propagate addrspace(4) through the introduced addrspacecast chains and to clean up
# the opportunities this exposes. Non-kernel jobs are left untouched.
#
# Returns the entry function, which may be a replacement for the one passed in.
function finish_ir!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
                    mod::LLVM.Module, entry::LLVM.Function)
    # nothing to do for device functions
    job.config.kernel || return entry

    entry = add_kernarg_address_spaces!(job, mod, entry)

    # optimize after the address space rewriting; the order of the passes matters
    cleanup_passes = (InferAddressSpacesPass, SROAPass, InstCombinePass,
                      EarlyCSEPass, SimplifyCFGPass)
    tm = llvm_machine(job.config.target)
    @dispose pb=NewPMPassBuilder() begin
        add!(pb, NewPMFunctionPassManager()) do fpm
            for make_pass in cleanup_passes
                add!(fpm, make_pass())
            end
        end
        run!(pb, mod, tm)
    end

    return entry
end
80+
# Move byref kernel parameters from the flat (addrspace 0) to the constant (addrspace 4)
# address space.
#
# AMDGPU kernel arguments live in the constant address space, which can be read with
# scalar loads (s_load). Julia emits byref parameters as addrspace(11) pointers, but
# RemoveJuliaAddrspacesPass turns every non-integral address space into flat pointers
# during optimization. Putting addrspace(4) back on these parameters lets the backend
# select s_load instead of flat_load when accessing struct fields.
#
# The function is replaced by a clone whose byref pointer parameters are in addrspace(4);
# a leading "conversion" block casts them back to flat pointers for the cloned body.
# Returns the replacement function, or `f` itself when no parameter qualifies.
#
# NOTE: must run after optimization, i.e., once RemoveJuliaAddrspacesPass has already
# converted Julia's addrspace(11) to flat (addrspace 0) on these parameters.
function add_kernarg_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                                     f::LLVM.Function)
    ft = function_type(f)
    param_types = parameters(ft)

    # Decide per parameter whether it needs rewriting: it has to carry the `byref`
    # attribute and be a flat pointer. The attribute is inspected directly instead of
    # re-classifying the arguments, which can fail on typed-pointer LLVM due to element
    # type mismatches in the classify_arguments assertions.
    byref_kind = LLVM.API.LLVMGetEnumAttributeKindForName("byref", 5)
    is_byref(attr) = attr isa TypeAttribute && kind(attr) == byref_kind
    needs_cast(i, typ) = any(is_byref, collect(parameter_attributes(f, i))) &&
                         typ isa LLVM.PointerType && addrspace(typ) == 0
    to_rewrite = Bool[needs_cast(i, typ) for (i, typ) in enumerate(param_types)]
    any(to_rewrite) || return f

    # build the new signature, with the selected pointers in the constant address space
    new_types = LLVMType[]
    for (typ, rewrite) in zip(param_types, to_rewrite)
        if !rewrite
            push!(new_types, typ)
        elseif supports_typed_pointers(context())
            push!(new_types, LLVM.PointerType(eltype(typ), #=constant=# 4))
        else
            push!(new_types, LLVM.PointerType(#=constant=# 4))
        end
    end

    new_f = LLVM.Function(mod, "", LLVM.FunctionType(return_type(ft), new_types))
    linkage!(new_f, linkage(f))
    new_params = parameters(new_f)
    for (old_arg, new_arg) in zip(parameters(f), new_params)
        LLVM.name!(new_arg, LLVM.name(old_arg))
    end

    @dispose builder=IRBuilder() begin
        position!(builder, BasicBlock(new_f, "conversion"))

        # The cloned IR still expects flat pointers, so cast the kernarg (4) pointers
        # back to flat (0). The AMDGPU backend's AMDGPULowerKernelArguments traces these
        # casts and produces s_load.
        replacements = LLVM.Value[]
        for (new_arg, typ, rewrite) in zip(new_params, param_types, to_rewrite)
            push!(replacements, rewrite ? addrspacecast!(builder, new_arg, typ) : new_arg)
        end

        # clone the body of the original function, substituting the arguments
        value_map = Dict{LLVM.Value, LLVM.Value}(
            old_arg => new_val for (old_arg, new_val) in zip(parameters(f), replacements)
        )
        value_map[f] = new_f
        clone_into!(new_f, f; value_map,
                    changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)

        # the conversion block falls through to the cloned entry block
        br!(builder, blocks(new_f)[2])
    end

    # Parameter attributes have to be copied AFTER clone_into!: CloneFunctionInto
    # overwrites all attributes via setAttributes, and because the byref arguments map to
    # addrspacecast instructions (not Arguments) their attributes are silently dropped
    # by LLVM's remapping.
    for i in eachindex(param_types)
        new_attrs = parameter_attributes(new_f, i)
        for attr in collect(parameter_attributes(f, i))
            push!(new_attrs, attr)
        end
    end

    # swap the new function in for the old one, taking over its name
    fn = LLVM.name(f)
    prune_constexpr_uses!(f)
    @assert isempty(uses(f))
    replace_metadata_uses!(f, new_f)
    erase!(f)
    LLVM.name!(new_f, fn)

    # merge the extra conversion block into the function body
    @dispose pb=NewPMPassBuilder() begin
        add!(pb, NewPMFunctionPassManager()) do fpm
            add!(fpm, SimplifyCFGPass())
        end
        run!(pb, mod)
    end

    return functions(mod)[fn]
end
197+
58198
59199# # LLVM passes
60200
0 commit comments