Skip to content

Commit 687f79d

Browse files
gbaraldi and claude committed
Fix typed-pointer compat and test failures for kernarg rewrite
On LLVM 16 (Julia ≤1.11) with typed pointers, add_kernarg_address_spaces! was creating opaque `ptr addrspace(4)` via `LLVM.PointerType(4)`, which introduced opaque pointers into a typed-pointer module. When InferAddressSpaces then propagated these through memcpy intrinsics (which use typed `i8 addrspace(N)*`), LLVM's verifier rejected the type mismatch.

Fix by following Metal's pattern: use `LLVM.PointerType(eltype, 4)` on typed-pointer contexts and the original param type as the addrspacecast target.

Also fix tests: Ptr{Float64} is lowered to i64 on Julia ≤1.11 and ptr on 1.12+, so use {{(i64|ptr)}} in filecheck and check for non-byref params instead of non-AS4 pointer params.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 7c72b3d commit 687f79d

2 files changed

Lines changed: 15 additions & 11 deletions

File tree

src/gcn.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,15 @@ function add_kernarg_address_spaces!(
115115
end
116116
needs_rewrite || return f
117117

118-
# generate the new function type with kernarg address space on byref params
118+
# generate the new function type with constant address space on byref params
119119
new_types = LLVMType[]
120120
for (i, param) in enumerate(parameters(ft))
121121
if byref_mask[i] && param isa LLVM.PointerType && addrspace(param) == 0
122-
push!(new_types, LLVM.PointerType(#=constant=# 4))
122+
if supports_typed_pointers(context())
123+
push!(new_types, LLVM.PointerType(eltype(param), #=constant=# 4))
124+
else
125+
push!(new_types, LLVM.PointerType(#=constant=# 4))
126+
end
123127
else
124128
push!(new_types, param)
125129
end
@@ -141,7 +145,7 @@ function add_kernarg_address_spaces!(
141145

142146
for (i, param) in enumerate(parameters(ft))
143147
if byref_mask[i] && param isa LLVM.PointerType && addrspace(param) == 0
144-
cast = addrspacecast!(builder, parameters(new_f)[i], LLVM.PointerType(0))
148+
cast = addrspacecast!(builder, parameters(new_f)[i], param)
145149
push!(new_args, cast)
146150
else
147151
push!(new_args, parameters(new_f)[i])

test/gcn.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,14 @@ end
100100
end
101101
end
102102

103-
# scalar Float64 and Ptr should NOT be in addrspace(4),
104-
# only the struct byref param should be
103+
# scalar Float64 should NOT be in addrspace(4),
104+
# only the struct byref param should be.
105+
# NOTE: Ptr{Float64} is lowered to i64 on Julia ≤1.11 and ptr on Julia 1.12+.
105106
@test @filecheck begin
106107
check"CHECK: define amdgpu_kernel void"
107108
check"CHECK-SAME: double"
108109
check"CHECK-SAME: ptr addrspace(4)"
109-
check"CHECK-SAME: ptr"
110+
check"CHECK-SAME: {{(i64|ptr)}}"
110111
GCN.code_llvm(mod.kernel, Tuple{Float64, mod.Params, Ptr{Float64}};
111112
dump_module=true, kernel=true)
112113
end
@@ -139,11 +140,10 @@ end
139140
has_as4 = any(p -> p isa LLVM.PointerType && addrspace(p) == 4, params)
140141
@test has_as4
141142

142-
# non-struct params (double, ptr) should NOT be in addrspace(4)
143-
non_as4_ptrs = filter(params) do p
144-
p isa LLVM.PointerType && addrspace(p) != 4
145-
end
146-
@test !isempty(non_as4_ptrs) # the Ptr{Float64} out param
143+
# non-struct params (double, and i64/ptr for Ptr{Float64}) should NOT
144+
# be in addrspace(4). Ptr{Float64} is i64 on Julia ≤1.11, ptr on 1.12+.
145+
non_byref = filter(p -> !(p isa LLVM.PointerType && addrspace(p) == 4), params)
146+
@test !isempty(non_byref) # double (and i64 or ptr) params
147147

148148
# byref attribute must be present
149149
ir_str = string(ir)

0 commit comments

Comments
 (0)