|
| 1 | +import tgpu, { d, std } from 'typegpu'; |
| 2 | +import { |
| 3 | + type BitonicSorter, |
| 4 | + type BitonicSorterOptions, |
| 5 | + createBitonicSorter, |
| 6 | + decomposeWorkgroups, |
| 7 | +} from '@typegpu/sort'; |
| 8 | +import { randf } from '@typegpu/noise'; |
| 9 | +import { fullScreenTriangle } from 'typegpu/common'; |
| 10 | +import { defineControls } from '../../common/defineControls.ts'; |
| 11 | + |
| 12 | +const maxBufferSize = await navigator.gpu.requestAdapter().then((adapter) => { |
| 13 | + if (!adapter) { |
| 14 | + throw new Error('No GPU adapter found'); |
| 15 | + } |
| 16 | + const limits = adapter.limits; |
| 17 | + return Math.min(limits.maxStorageBufferBindingSize, limits.maxBufferSize); |
| 18 | +}); |
| 19 | + |
| 20 | +const root = await tgpu.init({ |
| 21 | + device: { |
| 22 | + optionalFeatures: ['timestamp-query'], |
| 23 | + requiredLimits: { |
| 24 | + maxStorageBufferBindingSize: maxBufferSize, |
| 25 | + maxBufferSize: maxBufferSize, |
| 26 | + }, |
| 27 | + }, |
| 28 | +}); |
| 29 | +const hasTimestampQuery = root.enabledFeatures.has('timestamp-query'); |
| 30 | +const querySet = hasTimestampQuery ? root.createQuerySet('timestamp', 2) : null; |
| 31 | + |
| 32 | +const canvas = document.querySelector('canvas') as HTMLCanvasElement; |
| 33 | +const context = root.configureContext({ canvas }); |
| 34 | + |
| 35 | +const presentationFormat = navigator.gpu.getPreferredCanvasFormat(); |
| 36 | + |
| 37 | +const maxSide = Math.floor(Math.sqrt(maxBufferSize / 4)); |
| 38 | +const minLog = 2; // log_2(4) |
| 39 | +const maxLog = Math.floor(Math.log2(maxSide)); |
| 40 | +const arraySizeOptions = Array.from({ length: 8 }, (_, i) => { |
| 41 | + const side = Math.round(2 ** (minLog + (i * (maxLog - minLog)) / 7)); |
| 42 | + return side * side; |
| 43 | +}); |
| 44 | + |
| 45 | +type SortOrderKey = 'ascending' | 'descending' | 'bit-reversed' | 'xor-scatter'; |
| 46 | + |
| 47 | +const sortOrders: Record<SortOrderKey, BitonicSorterOptions> = { |
| 48 | + ascending: {}, |
| 49 | + descending: { |
| 50 | + compare: (a, b) => { |
| 51 | + 'use gpu'; |
| 52 | + return a > b; |
| 53 | + }, |
| 54 | + paddingValue: 0, |
| 55 | + }, |
| 56 | + 'bit-reversed': { |
| 57 | + compare: (a, b) => { |
| 58 | + 'use gpu'; |
| 59 | + return std.reverseBits(a) < std.reverseBits(b); |
| 60 | + }, |
| 61 | + }, |
| 62 | + 'xor-scatter': { |
| 63 | + compare: (a, b) => { |
| 64 | + 'use gpu'; |
| 65 | + return (a ^ 0xaa) < (b ^ 0xaa); |
| 66 | + }, |
| 67 | + }, |
| 68 | +}; |
| 69 | + |
| 70 | +const state = { |
| 71 | + arraySize: arraySizeOptions[2], |
| 72 | + sortOrder: 'ascending' as SortOrderKey, |
| 73 | +}; |
| 74 | + |
| 75 | +const WORKGROUP_SIZE = 256; |
| 76 | + |
| 77 | +const renderLayout = tgpu.bindGroupLayout({ |
| 78 | + data: { |
| 79 | + storage: d.arrayOf(d.u32), |
| 80 | + access: 'readonly', |
| 81 | + }, |
| 82 | +}); |
| 83 | + |
| 84 | +const initLayout = tgpu.bindGroupLayout({ |
| 85 | + data: { |
| 86 | + storage: d.arrayOf(d.u32), |
| 87 | + access: 'mutable', |
| 88 | + }, |
| 89 | +}); |
| 90 | + |
| 91 | +const initSeed = root.createUniform(d.f32, 0); |
| 92 | + |
| 93 | +const fragmentFn = tgpu.fragmentFn({ |
| 94 | + in: { uv: d.vec2f }, |
| 95 | + out: d.vec4f, |
| 96 | +})((input) => { |
| 97 | + const data = renderLayout.$.data; |
| 98 | + const arrayLength = data.length; |
| 99 | + |
| 100 | + const cols = d.u32(std.round(std.sqrt(d.f32(arrayLength)))); |
| 101 | + const rows = d.u32(std.round(arrayLength / cols)); |
| 102 | + |
| 103 | + const col = d.u32(std.floor(input.uv.x * d.f32(cols))); |
| 104 | + const row = d.u32(std.floor(input.uv.y * d.f32(rows))); |
| 105 | + const idx = row * cols + col; |
| 106 | + |
| 107 | + if (idx >= arrayLength) { |
| 108 | + return d.vec4f(0.1, 0.1, 0.1, 1); |
| 109 | + } |
| 110 | + |
| 111 | + const value = data[idx]; |
| 112 | + const normalized = value / 255; |
| 113 | + |
| 114 | + return d.vec4f(normalized, normalized, normalized, 1); |
| 115 | +}); |
| 116 | + |
| 117 | +const initKernel = tgpu.computeFn({ |
| 118 | + workgroupSize: [WORKGROUP_SIZE], |
| 119 | + in: { |
| 120 | + gid: d.builtin.globalInvocationId, |
| 121 | + numWorkgroups: d.builtin.numWorkgroups, |
| 122 | + }, |
| 123 | +})((input) => { |
| 124 | + const spanX = input.numWorkgroups.x * WORKGROUP_SIZE; |
| 125 | + const spanY = input.numWorkgroups.y * spanX; |
| 126 | + const idx = input.gid.x + input.gid.y * spanX + input.gid.z * spanY; |
| 127 | + |
| 128 | + if (idx >= initLayout.$.data.length) { |
| 129 | + return; |
| 130 | + } |
| 131 | + |
| 132 | + randf.seed3(d.vec3f(d.f32(idx & 0xffff), d.f32(idx >> 16), initSeed.$)); |
| 133 | + const n = randf.sample(); |
| 134 | + initLayout.$.data[idx] = d.u32(std.floor(n * 256.0)); |
| 135 | +}); |
| 136 | + |
| 137 | +const renderPipeline = root.createRenderPipeline({ |
| 138 | + vertex: fullScreenTriangle, |
| 139 | + fragment: fragmentFn, |
| 140 | + targets: { format: presentationFormat }, |
| 141 | +}); |
| 142 | + |
| 143 | +const initPipeline = root.createComputePipeline({ compute: initKernel }); |
| 144 | + |
| 145 | +let buffer = root.createBuffer(d.arrayOf(d.u32, state.arraySize)).$usage('storage'); |
| 146 | + |
| 147 | +let bindGroup = root.createBindGroup(renderLayout, { |
| 148 | + data: buffer, |
| 149 | +}); |
| 150 | +let initBindGroup = root.createBindGroup(initLayout, { |
| 151 | + data: buffer, |
| 152 | +}); |
| 153 | + |
| 154 | +function createSorters(buf: typeof buffer) { |
| 155 | + return Object.fromEntries( |
| 156 | + Object.entries(sortOrders).map(([key, opts]) => [key, createBitonicSorter(root, buf, opts)]), |
| 157 | + ) as Record<SortOrderKey, BitonicSorter>; |
| 158 | +} |
| 159 | + |
| 160 | +let sorters = createSorters(buffer); |
| 161 | + |
| 162 | +function recreateBuffer() { |
| 163 | + for (const s of Object.values(sorters)) { |
| 164 | + s.destroy(); |
| 165 | + } |
| 166 | + buffer.destroy(); |
| 167 | + |
| 168 | + buffer = root.createBuffer(d.arrayOf(d.u32, state.arraySize)).$usage('storage'); |
| 169 | + |
| 170 | + bindGroup = root.createBindGroup(renderLayout, { |
| 171 | + data: buffer, |
| 172 | + }); |
| 173 | + |
| 174 | + initBindGroup = root.createBindGroup(initLayout, { |
| 175 | + data: buffer, |
| 176 | + }); |
| 177 | + |
| 178 | + sorters = createSorters(buffer); |
| 179 | +} |
| 180 | + |
| 181 | +function generateRandomArray() { |
| 182 | + const workgroupsTotal = Math.ceil(state.arraySize / WORKGROUP_SIZE); |
| 183 | + const [workgroupsX, workgroupsY, workgroupsZ] = decomposeWorkgroups(workgroupsTotal); |
| 184 | + |
| 185 | + initSeed.write(Math.random()); |
| 186 | + |
| 187 | + initPipeline.with(initBindGroup).dispatchWorkgroups(workgroupsX, workgroupsY, workgroupsZ); |
| 188 | + |
| 189 | + render(); |
| 190 | +} |
| 191 | + |
| 192 | +function render() { |
| 193 | + renderPipeline.withColorAttachment({ view: context }).with(bindGroup).draw(3); |
| 194 | +} |
| 195 | + |
| 196 | +const overlay = document.getElementById('sort-overlay') as HTMLDivElement; |
| 197 | +const spinnerEl = document.getElementById('sort-spinner') as HTMLDivElement; |
| 198 | +const statusEl = document.getElementById('sort-status') as HTMLSpanElement; |
| 199 | +canvas.parentElement?.appendChild(overlay); |
| 200 | + |
| 201 | +let hideTimeoutId: ReturnType<typeof setTimeout> | null = null; |
| 202 | + |
| 203 | +function showOverlay(text: string, showSpinner = true) { |
| 204 | + if (hideTimeoutId !== null) { |
| 205 | + clearTimeout(hideTimeoutId); |
| 206 | + hideTimeoutId = null; |
| 207 | + } |
| 208 | + spinnerEl.hidden = !showSpinner; |
| 209 | + statusEl.textContent = text; |
| 210 | + overlay.hidden = false; |
| 211 | + overlay.classList.add('visible'); |
| 212 | +} |
| 213 | + |
| 214 | +function hideOverlay(delayMs = 1500) { |
| 215 | + hideTimeoutId = setTimeout(() => { |
| 216 | + hideTimeoutId = null; |
| 217 | + overlay.classList.remove('visible'); |
| 218 | + overlay.addEventListener('transitionend', () => (overlay.hidden = true), { |
| 219 | + once: true, |
| 220 | + }); |
| 221 | + }, delayMs); |
| 222 | +} |
| 223 | + |
| 224 | +async function sort() { |
| 225 | + const sorter = sorters[state.sortOrder]; |
| 226 | + |
| 227 | + showOverlay('Sorting...'); |
| 228 | + sorter.run({ querySet: querySet ?? undefined }); |
| 229 | + |
| 230 | + let gpuTimeMs: number | null = null; |
| 231 | + if (querySet?.available) { |
| 232 | + querySet.resolve(); |
| 233 | + const timestamps = await querySet.read(); |
| 234 | + gpuTimeMs = Number(timestamps[1] - timestamps[0]) / 1_000_000; |
| 235 | + } |
| 236 | + |
| 237 | + render(); |
| 238 | + |
| 239 | + const timeStr = |
| 240 | + gpuTimeMs !== null |
| 241 | + ? ` in ${ |
| 242 | + gpuTimeMs >= 1000 ? `${(gpuTimeMs / 1000).toFixed(2)}s` : `${gpuTimeMs.toFixed(2)}ms` |
| 243 | + }` |
| 244 | + : ''; |
| 245 | + showOverlay(`\u2714 Sorted${timeStr}`, false); |
| 246 | + hideOverlay(); |
| 247 | +} |
| 248 | + |
| 249 | +// #region Example controls & Cleanup |
| 250 | + |
| 251 | +const sortOrderKeys = Object.keys(sortOrders) as SortOrderKey[]; |
| 252 | + |
| 253 | +export const controls = defineControls({ |
| 254 | + 'Array Size': { |
| 255 | + initial: arraySizeOptions[2], |
| 256 | + options: arraySizeOptions, |
| 257 | + onSelectChange: (value) => { |
| 258 | + state.arraySize = isNaN(value) ? 64 : value; |
| 259 | + recreateBuffer(); |
| 260 | + generateRandomArray(); |
| 261 | + }, |
| 262 | + }, |
| 263 | + 'Sort Order': { |
| 264 | + initial: 'ascending', |
| 265 | + options: sortOrderKeys, |
| 266 | + onSelectChange: (value) => { |
| 267 | + state.sortOrder = value; |
| 268 | + }, |
| 269 | + }, |
| 270 | + Reshuffle: { onButtonClick: generateRandomArray }, |
| 271 | + Sort: { onButtonClick: sort }, |
| 272 | +}); |
| 273 | + |
| 274 | +export function onCleanup() { |
| 275 | + for (const s of Object.values(sorters)) { |
| 276 | + s.destroy(); |
| 277 | + } |
| 278 | + root.destroy(); |
| 279 | +} |
| 280 | + |
| 281 | +// #endregion |
0 commit comments