Skip to content

Commit 59ecc8b

Browse files
authored
feat: Add @typegpu/sort scaffolding with simple bitonic sort implementation (#2142)
1 parent f071585 commit 59ecc8b

31 files changed

Lines changed: 1008 additions & 112 deletions

File tree

apps/typegpu-docs/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@
2929
"@stackblitz/sdk": "^1.11.0",
3030
"@tailwindcss/vite": "^4.1.18",
3131
"@typegpu/color": "workspace:*",
32-
"@typegpu/concurrent-scan": "workspace:*",
3332
"@typegpu/geometry": "workspace:*",
3433
"@typegpu/noise": "workspace:*",
3534
"@typegpu/sdf": "workspace:*",
35+
"@typegpu/sort": "workspace:*",
3636
"@typegpu/three": "workspace:*",
3737
"@types/react": "^19.1.8",
3838
"@types/react-dom": "^19.1.6",
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
<canvas></canvas>
2+
<div id="sort-overlay" hidden>
3+
<div id="sort-spinner"></div>
4+
<span id="sort-status"></span>
5+
</div>
6+
<style>
7+
#sort-overlay {
8+
position: absolute;
9+
inset: 0;
10+
flex-direction: column;
11+
align-items: center;
12+
justify-content: center;
13+
gap: 12px;
14+
backdrop-filter: blur(8px);
15+
background: rgba(23, 23, 36, 0.5);
16+
opacity: 0;
17+
transition: opacity 0.3s ease-in-out;
18+
}
19+
20+
#sort-overlay:not([hidden]) {
21+
display: flex;
22+
}
23+
24+
#sort-overlay.visible {
25+
opacity: 1;
26+
}
27+
28+
#sort-spinner {
29+
width: 40px;
30+
height: 40px;
31+
border: 3px solid #2e295f;
32+
border-top-color: #6453d2;
33+
border-radius: 50%;
34+
animation: spin 0.8s linear infinite;
35+
}
36+
37+
#sort-status {
38+
color: #c3c4f1;
39+
font-size: 0.875rem;
40+
font-family: 'Aeonik', sans-serif;
41+
}
42+
43+
@keyframes spin {
44+
to {
45+
transform: rotate(360deg);
46+
}
47+
}
48+
</style>
Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
import tgpu, { d, std } from 'typegpu';
2+
import {
3+
type BitonicSorter,
4+
type BitonicSorterOptions,
5+
createBitonicSorter,
6+
decomposeWorkgroups,
7+
} from '@typegpu/sort';
8+
import { randf } from '@typegpu/noise';
9+
import { fullScreenTriangle } from 'typegpu/common';
10+
import { defineControls } from '../../common/defineControls.ts';
11+
12+
const maxBufferSize = await navigator.gpu.requestAdapter().then((adapter) => {
13+
if (!adapter) {
14+
throw new Error('No GPU adapter found');
15+
}
16+
const limits = adapter.limits;
17+
return Math.min(limits.maxStorageBufferBindingSize, limits.maxBufferSize);
18+
});
19+
20+
const root = await tgpu.init({
21+
device: {
22+
optionalFeatures: ['timestamp-query'],
23+
requiredLimits: {
24+
maxStorageBufferBindingSize: maxBufferSize,
25+
maxBufferSize: maxBufferSize,
26+
},
27+
},
28+
});
29+
const hasTimestampQuery = root.enabledFeatures.has('timestamp-query');
30+
const querySet = hasTimestampQuery ? root.createQuerySet('timestamp', 2) : null;
31+
32+
const canvas = document.querySelector('canvas') as HTMLCanvasElement;
33+
const context = root.configureContext({ canvas });
34+
35+
const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
36+
37+
const maxSide = Math.floor(Math.sqrt(maxBufferSize / 4));
38+
const minLog = 2; // log_2(4)
39+
const maxLog = Math.floor(Math.log2(maxSide));
40+
const arraySizeOptions = Array.from({ length: 8 }, (_, i) => {
41+
const side = Math.round(2 ** (minLog + (i * (maxLog - minLog)) / 7));
42+
return side * side;
43+
});
44+
45+
type SortOrderKey = 'ascending' | 'descending' | 'bit-reversed' | 'xor-scatter';
46+
47+
const sortOrders: Record<SortOrderKey, BitonicSorterOptions> = {
48+
ascending: {},
49+
descending: {
50+
compare: (a, b) => {
51+
'use gpu';
52+
return a > b;
53+
},
54+
paddingValue: 0,
55+
},
56+
'bit-reversed': {
57+
compare: (a, b) => {
58+
'use gpu';
59+
return std.reverseBits(a) < std.reverseBits(b);
60+
},
61+
},
62+
'xor-scatter': {
63+
compare: (a, b) => {
64+
'use gpu';
65+
return (a ^ 0xaa) < (b ^ 0xaa);
66+
},
67+
},
68+
};
69+
70+
const state = {
71+
arraySize: arraySizeOptions[2],
72+
sortOrder: 'ascending' as SortOrderKey,
73+
};
74+
75+
const WORKGROUP_SIZE = 256;
76+
77+
const renderLayout = tgpu.bindGroupLayout({
78+
data: {
79+
storage: d.arrayOf(d.u32),
80+
access: 'readonly',
81+
},
82+
});
83+
84+
const initLayout = tgpu.bindGroupLayout({
85+
data: {
86+
storage: d.arrayOf(d.u32),
87+
access: 'mutable',
88+
},
89+
});
90+
91+
const initSeed = root.createUniform(d.f32, 0);
92+
93+
const fragmentFn = tgpu.fragmentFn({
94+
in: { uv: d.vec2f },
95+
out: d.vec4f,
96+
})((input) => {
97+
const data = renderLayout.$.data;
98+
const arrayLength = data.length;
99+
100+
const cols = d.u32(std.round(std.sqrt(d.f32(arrayLength))));
101+
const rows = d.u32(std.round(arrayLength / cols));
102+
103+
const col = d.u32(std.floor(input.uv.x * d.f32(cols)));
104+
const row = d.u32(std.floor(input.uv.y * d.f32(rows)));
105+
const idx = row * cols + col;
106+
107+
if (idx >= arrayLength) {
108+
return d.vec4f(0.1, 0.1, 0.1, 1);
109+
}
110+
111+
const value = data[idx];
112+
const normalized = value / 255;
113+
114+
return d.vec4f(normalized, normalized, normalized, 1);
115+
});
116+
117+
const initKernel = tgpu.computeFn({
118+
workgroupSize: [WORKGROUP_SIZE],
119+
in: {
120+
gid: d.builtin.globalInvocationId,
121+
numWorkgroups: d.builtin.numWorkgroups,
122+
},
123+
})((input) => {
124+
const spanX = input.numWorkgroups.x * WORKGROUP_SIZE;
125+
const spanY = input.numWorkgroups.y * spanX;
126+
const idx = input.gid.x + input.gid.y * spanX + input.gid.z * spanY;
127+
128+
if (idx >= initLayout.$.data.length) {
129+
return;
130+
}
131+
132+
randf.seed3(d.vec3f(d.f32(idx & 0xffff), d.f32(idx >> 16), initSeed.$));
133+
const n = randf.sample();
134+
initLayout.$.data[idx] = d.u32(std.floor(n * 256.0));
135+
});
136+
137+
const renderPipeline = root.createRenderPipeline({
138+
vertex: fullScreenTriangle,
139+
fragment: fragmentFn,
140+
targets: { format: presentationFormat },
141+
});
142+
143+
const initPipeline = root.createComputePipeline({ compute: initKernel });
144+
145+
let buffer = root.createBuffer(d.arrayOf(d.u32, state.arraySize)).$usage('storage');
146+
147+
let bindGroup = root.createBindGroup(renderLayout, {
148+
data: buffer,
149+
});
150+
let initBindGroup = root.createBindGroup(initLayout, {
151+
data: buffer,
152+
});
153+
154+
function createSorters(buf: typeof buffer) {
155+
return Object.fromEntries(
156+
Object.entries(sortOrders).map(([key, opts]) => [key, createBitonicSorter(root, buf, opts)]),
157+
) as Record<SortOrderKey, BitonicSorter>;
158+
}
159+
160+
let sorters = createSorters(buffer);
161+
162+
function recreateBuffer() {
163+
for (const s of Object.values(sorters)) {
164+
s.destroy();
165+
}
166+
buffer.destroy();
167+
168+
buffer = root.createBuffer(d.arrayOf(d.u32, state.arraySize)).$usage('storage');
169+
170+
bindGroup = root.createBindGroup(renderLayout, {
171+
data: buffer,
172+
});
173+
174+
initBindGroup = root.createBindGroup(initLayout, {
175+
data: buffer,
176+
});
177+
178+
sorters = createSorters(buffer);
179+
}
180+
181+
function generateRandomArray() {
182+
const workgroupsTotal = Math.ceil(state.arraySize / WORKGROUP_SIZE);
183+
const [workgroupsX, workgroupsY, workgroupsZ] = decomposeWorkgroups(workgroupsTotal);
184+
185+
initSeed.write(Math.random());
186+
187+
initPipeline.with(initBindGroup).dispatchWorkgroups(workgroupsX, workgroupsY, workgroupsZ);
188+
189+
render();
190+
}
191+
192+
function render() {
193+
renderPipeline.withColorAttachment({ view: context }).with(bindGroup).draw(3);
194+
}
195+
196+
const overlay = document.getElementById('sort-overlay') as HTMLDivElement;
197+
const spinnerEl = document.getElementById('sort-spinner') as HTMLDivElement;
198+
const statusEl = document.getElementById('sort-status') as HTMLSpanElement;
199+
canvas.parentElement?.appendChild(overlay);
200+
201+
let hideTimeoutId: ReturnType<typeof setTimeout> | null = null;
202+
203+
function showOverlay(text: string, showSpinner = true) {
204+
if (hideTimeoutId !== null) {
205+
clearTimeout(hideTimeoutId);
206+
hideTimeoutId = null;
207+
}
208+
spinnerEl.hidden = !showSpinner;
209+
statusEl.textContent = text;
210+
overlay.hidden = false;
211+
overlay.classList.add('visible');
212+
}
213+
214+
function hideOverlay(delayMs = 1500) {
215+
hideTimeoutId = setTimeout(() => {
216+
hideTimeoutId = null;
217+
overlay.classList.remove('visible');
218+
overlay.addEventListener('transitionend', () => (overlay.hidden = true), {
219+
once: true,
220+
});
221+
}, delayMs);
222+
}
223+
224+
async function sort() {
225+
const sorter = sorters[state.sortOrder];
226+
227+
showOverlay('Sorting...');
228+
sorter.run({ querySet: querySet ?? undefined });
229+
230+
let gpuTimeMs: number | null = null;
231+
if (querySet?.available) {
232+
querySet.resolve();
233+
const timestamps = await querySet.read();
234+
gpuTimeMs = Number(timestamps[1] - timestamps[0]) / 1_000_000;
235+
}
236+
237+
render();
238+
239+
const timeStr =
240+
gpuTimeMs !== null
241+
? ` in ${
242+
gpuTimeMs >= 1000 ? `${(gpuTimeMs / 1000).toFixed(2)}s` : `${gpuTimeMs.toFixed(2)}ms`
243+
}`
244+
: '';
245+
showOverlay(`\u2714 Sorted${timeStr}`, false);
246+
hideOverlay();
247+
}
248+
249+
// #region Example controls & Cleanup
250+
251+
const sortOrderKeys = Object.keys(sortOrders) as SortOrderKey[];
252+
253+
export const controls = defineControls({
254+
'Array Size': {
255+
initial: arraySizeOptions[2],
256+
options: arraySizeOptions,
257+
onSelectChange: (value) => {
258+
state.arraySize = isNaN(value) ? 64 : value;
259+
recreateBuffer();
260+
generateRandomArray();
261+
},
262+
},
263+
'Sort Order': {
264+
initial: 'ascending',
265+
options: sortOrderKeys,
266+
onSelectChange: (value) => {
267+
state.sortOrder = value;
268+
},
269+
},
270+
Reshuffle: { onButtonClick: generateRandomArray },
271+
Sort: { onButtonClick: sort },
272+
});
273+
274+
export function onCleanup() {
275+
for (const s of Object.values(sorters)) {
276+
s.destroy();
277+
}
278+
root.destroy();
279+
}
280+
281+
// #endregion
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"title": "Bitonic Sort",
3+
"category": "algorithms",
4+
"tags": ["experimental", "compute"]
5+
}
107 KB
Loading

apps/typegpu-docs/src/examples/algorithms/concurrent-chart/calculator.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { initCache, prefixScan } from '@typegpu/concurrent-scan';
1+
import { createPrefixScanComputer, prefixScan } from '@typegpu/sort';
22
import type { TgpuRoot } from 'typegpu';
33
import { d, std } from 'typegpu';
44

@@ -45,7 +45,7 @@ export async function performCalculationsWithTime(
4545
const jsTime = performance.now() - jsStartTime;
4646

4747
// GPU version
48-
initCache(root, { operation: std.add, identityElement: 0 });
48+
createPrefixScanComputer(root, { operation: std.add, identityElement: 0 });
4949
const querySet = root.createQuerySet('timestamp', 2);
5050
const gpuStartTime = performance.now();
5151
const calcResult = prefixScan(

apps/typegpu-docs/src/examples/tests/prefix-scan/functions.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import tgpu from 'typegpu';
22
import * as d from 'typegpu/data';
33
import * as std from 'typegpu/std';
4-
import type { BinaryOp } from '@typegpu/concurrent-scan';
4+
import type { BinaryOp } from '@typegpu/sort';
55

66
// tgpu functions
77

apps/typegpu-docs/src/examples/tests/prefix-scan/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import tgpu from 'typegpu';
22
import * as d from 'typegpu/data';
3-
import { type BinaryOp, prefixScan, scan } from '@typegpu/concurrent-scan';
3+
import { type BinaryOp, prefixScan, scan } from '@typegpu/sort';
44
import * as std from 'typegpu/std';
55
import { addFn, concat10, isArrayEqual, mulFn, prefixScanJS, scanJS } from './functions.ts';
66

apps/typegpu-docs/src/utils/examples/sandboxModules.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,11 @@ export const SANDBOX_MODULES: Record<string, SandboxModuleDefinition> = {
137137
import: { reroute: 'typegpu-color/src/index.ts' },
138138
typeDef: { reroute: 'typegpu-color/src/index.ts' },
139139
},
140-
'@typegpu/concurrent-scan': {
141-
import: { reroute: 'typegpu-concurrent-scan/src/index.ts' },
142-
typeDef: { reroute: 'typegpu-concurrent-scan/src/index.ts' },
143-
},
144140
'@typegpu/three': {
145141
typeDef: { reroute: 'typegpu-three/src/index.ts' },
146142
},
143+
'@typegpu/sort': {
144+
import: { reroute: 'typegpu-sort/src/index.ts' },
145+
typeDef: { reroute: 'typegpu-sort/src/index.ts' },
146+
},
147147
};

0 commit comments

Comments
 (0)