Skip to content

Commit 7de6653

Browse files
authored
feat: Add Indirect buffer usage and support .dispatchWorkgroupsIndirect API (#2105)
1 parent 1b02ad6 commit 7de6653

17 files changed

Lines changed: 1768 additions & 184 deletions

File tree

apps/typegpu-docs/src/examples/algorithms/mnist-inference/index.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ const hasTimestampQuery = root.enabledFeatures.has('timestamp-query');
1919
const hasSubgroups = root.enabledFeatures.has('subgroups');
2020
let useSubgroups = hasSubgroups;
2121

22-
const canvasData = Array.from({ length: (SIZE ** 2) }, () => 0);
22+
const canvasData = Array.from({ length: SIZE ** 2 }, () => 0);
2323

2424
// Shaders
2525

@@ -314,7 +314,7 @@ function centerImage(data: number[]) {
314314
const offsetX = Math.round(SIZE / 2 - x);
315315
const offsetY = Math.round(SIZE / 2 - y);
316316

317-
const newData = Array.from({ length: (SIZE * SIZE) }, () => 0);
317+
const newData = Array.from({ length: SIZE * SIZE }, () => 0);
318318
for (let i = 0; i < SIZE; i++) {
319319
for (let j = 0; j < SIZE; j++) {
320320
const index = i * SIZE + j;

packages/typegpu-three/src/typegpu-node.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ export const fromTSL = tgpu.comptime(
397397
let nodeType: string | null = null;
398398
try { // sometimes it needs information (overrideNodes) from compilation context which is not present
399399
nodeType = node.getNodeType(sharedBuilder);
400-
} catch (e) {
400+
} catch {
401401
console.warn(
402402
`fromTSL: failed to infer node type via getNodeType; skipping type comparison.`,
403403
);

packages/typegpu/src/core/buffer/buffer.ts

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,24 @@ export interface IndexFlag {
5656
usableAsIndex: true;
5757
}
5858

59+
export interface IndirectFlag {
60+
usableAsIndirect: true;
61+
}
62+
5963
/**
6064
* @deprecated Use VertexFlag instead.
6165
*/
6266
export type Vertex = VertexFlag;
6367

64-
type LiteralToUsageType<T extends 'uniform' | 'storage' | 'vertex' | 'index'> =
65-
T extends 'uniform' ? UniformFlag
66-
: T extends 'storage' ? StorageFlag
67-
: T extends 'vertex' ? VertexFlag
68-
: T extends 'index' ? IndexFlag
69-
: never;
68+
type UsageLiteral = 'uniform' | 'storage' | 'vertex' | 'index' | 'indirect';
69+
70+
type LiteralToUsageType<T extends UsageLiteral> = T extends 'uniform'
71+
? UniformFlag
72+
: T extends 'storage' ? StorageFlag
73+
: T extends 'vertex' ? VertexFlag
74+
: T extends 'index' ? IndexFlag
75+
: T extends 'indirect' ? IndirectFlag
76+
: never;
7077

7178
type ViewUsages<TBuffer extends TgpuBuffer<BaseData>> =
7279
| (boolean extends TBuffer['usableAsUniform'] ? never : 'uniform')
@@ -89,7 +96,9 @@ type InnerValidUsagesFor<T> = {
8996
| (IsValidStorageSchema<T> extends true ? 'storage' : never)
9097
| (IsValidUniformSchema<T> extends true ? 'uniform' : never)
9198
| (IsValidVertexSchema<T> extends true ? 'vertex' : never)
92-
| (IsValidIndexSchema<T> extends true ? 'index' : never);
99+
| (IsValidIndexSchema<T> extends true ? 'index' : never)
100+
// there is no way to check at the type level if a buffer can be used as indirect (size >= 12 bytes)
101+
| 'indirect';
93102
};
94103

95104
export type ValidUsagesFor<T> = InnerValidUsagesFor<T>['usage'];
@@ -107,6 +116,7 @@ export interface TgpuBuffer<TData extends BaseData> extends TgpuNamable {
107116
usableAsStorage: boolean;
108117
usableAsVertex: boolean;
109118
usableAsIndex: boolean;
119+
usableAsIndirect: boolean;
110120

111121
$usage<
112122
T extends [
@@ -184,13 +194,13 @@ class TgpuBufferImpl<TData extends BaseData> implements TgpuBuffer<TData> {
184194
usableAsStorage = false;
185195
usableAsVertex = false;
186196
usableAsIndex = false;
197+
usableAsIndirect = false;
187198

188199
constructor(
189200
root: ExperimentalTgpuRoot,
190201
public readonly dataType: TData,
191202
public readonly initialOrBuffer?: Infer<TData> | GPUBuffer,
192-
private readonly _disallowedUsages?:
193-
('uniform' | 'storage' | 'vertex' | 'index')[],
203+
private readonly _disallowedUsages?: UsageLiteral[],
194204
) {
195205
this.#device = root.device;
196206
if (isGPUBuffer(initialOrBuffer)) {
@@ -236,7 +246,7 @@ class TgpuBufferImpl<TData extends BaseData> implements TgpuBuffer<TData> {
236246
return this;
237247
}
238248

239-
$usage<T extends ('uniform' | 'storage' | 'vertex' | 'index')[]>(
249+
$usage<T extends UsageLiteral[]>(
240250
...usages: T
241251
): this & UnionToIntersection<LiteralToUsageType<T[number]>> {
242252
for (const usage of usages) {
@@ -250,10 +260,12 @@ class TgpuBufferImpl<TData extends BaseData> implements TgpuBuffer<TData> {
250260
this.flags |= usage === 'storage' ? GPUBufferUsage.STORAGE : 0;
251261
this.flags |= usage === 'vertex' ? GPUBufferUsage.VERTEX : 0;
252262
this.flags |= usage === 'index' ? GPUBufferUsage.INDEX : 0;
263+
this.flags |= usage === 'indirect' ? GPUBufferUsage.INDIRECT : 0;
253264
this.usableAsUniform = this.usableAsUniform || usage === 'uniform';
254265
this.usableAsStorage = this.usableAsStorage || usage === 'storage';
255266
this.usableAsVertex = this.usableAsVertex || usage === 'vertex';
256267
this.usableAsIndex = this.usableAsIndex || usage === 'index';
268+
this.usableAsIndirect = this.usableAsIndirect || usage === 'indirect';
257269
}
258270
return this as this & UnionToIntersection<LiteralToUsageType<T[number]>>;
259271
}

packages/typegpu/src/core/pipeline/computePipeline.ts

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import type { AnyComputeBuiltin } from '../../builtin.ts';
22
import type { TgpuQuerySet } from '../../core/querySet/querySet.ts';
33
import { type ResolvedSnippet, snip } from '../../data/snippet.ts';
4+
import { sizeOf } from '../../data/sizeOf.ts';
5+
import type { AnyWgslData } from '../../data/wgslTypes.ts';
46
import { Void } from '../../data/wgslTypes.ts';
57
import { MissingBindGroupsError } from '../../errors.ts';
68
import { type ResolutionResult, resolve } from '../../resolutionCtx.ts';
79
import type { TgpuNamable } from '../../shared/meta.ts';
810
import { getName, PERF, setName } from '../../shared/meta.ts';
11+
912
import { $getNameForward, $internal, $resolve } from '../../shared/symbols.ts';
1013
import {
1114
isBindGroup,
@@ -26,6 +29,10 @@ import { namespace } from '../resolve/namespace.ts';
2629
import type { ExperimentalTgpuRoot } from '../root/rootTypes.ts';
2730
import type { TgpuSlot } from '../slot/slotTypes.ts';
2831
import { warnIfOverflow } from './limitsOverflow.ts';
32+
import {
33+
getOffsetInfoAt,
34+
type PrimitiveOffsetInfo,
35+
} from '../../data/offsetUtils.ts';
2936
import {
3037
createWithPerformanceCallback,
3138
createWithTimestampWrites,
@@ -34,6 +41,7 @@ import {
3441
type TimestampWritesPriors,
3542
triggerPerformanceCallback,
3643
} from './timeable.ts';
44+
import type { IndirectFlag, TgpuBuffer } from '../buffer/buffer.ts';
3745

3846
interface ComputePipelineInternals {
3947
readonly rawPipeline: GPUComputePipeline;
@@ -65,6 +73,19 @@ export interface TgpuComputePipeline
6573
y?: number,
6674
z?: number,
6775
): void;
76+
77+
/**
78+
* Dispatches compute workgroups using parameters read from a buffer.
79+
* The buffer must contain 3 consecutive u32 values (x, y, z workgroup counts).
80+
* To get the correct offset within complex data structures, use `d.getOffsetInfoAt(...)`.
81+
*
82+
* @param indirectBuffer - Buffer marked with 'indirect' usage containing dispatch parameters
83+
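* @param start - Location of the first dispatch parameter, given either as a `PrimitiveOffsetInfo` or as a raw byte offset (must be a multiple of 4). If not provided, starts at offset 0. Prefer offsets obtained from `d.getOffsetInfoAt(...)`: they carry contiguity information, whereas a non-zero raw number triggers a warning and is assumed to have 12 contiguous bytes available.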
84+
*/
85+
dispatchWorkgroupsIndirect<T extends AnyWgslData>(
86+
indirectBuffer: TgpuBuffer<T> & IndirectFlag,
87+
start?: PrimitiveOffsetInfo | number,
88+
): void;
6889
}
6990

7091
export declare namespace TgpuComputePipeline {
@@ -101,6 +122,27 @@ type Memo = {
101122
logResources: LogResources | undefined;
102123
};
103124

125+
function validateIndirectBufferSize(
126+
bufferSize: number,
127+
offset: number,
128+
requiredBytes: number,
129+
operation: string,
130+
): void {
131+
if (offset + requiredBytes > bufferSize) {
132+
throw new Error(
133+
`Buffer too small for ${operation}. ` +
134+
`Required: ${requiredBytes} bytes at offset ${offset}, ` +
135+
`but buffer is only ${bufferSize} bytes.`,
136+
);
137+
}
138+
139+
if (offset % 4 !== 0) {
140+
throw new Error(
141+
`Indirect buffer offset must be a multiple of 4. Got: ${offset}`,
142+
);
143+
}
144+
}
145+
104146
class TgpuComputePipelineImpl implements TgpuComputePipeline {
105147
public readonly [$internal]: ComputePipelineInternals;
106148
public readonly resourceType = 'compute-pipeline';
@@ -192,6 +234,55 @@ class TgpuComputePipelineImpl implements TgpuComputePipeline {
192234
x: number,
193235
y?: number,
194236
z?: number,
237+
): void {
238+
this._executeComputePass((pass) => pass.dispatchWorkgroups(x, y, z));
239+
}
240+
241+
dispatchWorkgroupsIndirect<T extends AnyWgslData>(
242+
indirectBuffer: TgpuBuffer<T> & IndirectFlag,
243+
start?: PrimitiveOffsetInfo | number,
244+
): void {
245+
const DISPATCH_SIZE = 12; // 3 x u32 (x, y, z)
246+
247+
let offsetInfo = start ?? getOffsetInfoAt(indirectBuffer.dataType);
248+
249+
if (typeof offsetInfo === 'number') {
250+
if (offsetInfo === 0) {
251+
offsetInfo = getOffsetInfoAt(indirectBuffer.dataType);
252+
} else {
253+
console.warn(
254+
`dispatchWorkgroupsIndirect: Provided start offset ${offsetInfo} as a raw number. Use getOffsetInfoAt(...) to include contiguous padding info for safer validation.`,
255+
);
256+
// When only an offset is provided, assume we have at least 12 bytes contiguous.
257+
offsetInfo = {
258+
offset: offsetInfo,
259+
contiguous: DISPATCH_SIZE,
260+
};
261+
}
262+
}
263+
264+
const { offset, contiguous } = offsetInfo;
265+
266+
validateIndirectBufferSize(
267+
sizeOf(indirectBuffer.dataType),
268+
offset,
269+
DISPATCH_SIZE,
270+
'dispatchWorkgroupsIndirect',
271+
);
272+
273+
if (contiguous < DISPATCH_SIZE) {
274+
console.warn(
275+
`dispatchWorkgroupsIndirect: Starting at offset ${offset}, only ${contiguous} contiguous bytes are available before padding. Dispatch requires ${DISPATCH_SIZE} bytes (3 x u32). Reading across padding may result in undefined behavior.`,
276+
);
277+
}
278+
279+
this._executeComputePass((pass) =>
280+
pass.dispatchWorkgroupsIndirect(indirectBuffer.buffer, offset)
281+
);
282+
}
283+
284+
private _executeComputePass(
285+
dispatch: (pass: GPUComputePassEncoder) => void,
195286
): void {
196287
const memo = this._core.unwrap();
197288
const { root } = this._core;
@@ -231,7 +322,7 @@ class TgpuComputePipelineImpl implements TgpuComputePipeline {
231322
throw new MissingBindGroupsError(missingBindGroups);
232323
}
233324

234-
pass.dispatchWorkgroups(x, y, z);
325+
dispatch(pass);
235326
pass.end();
236327
root.device.queue.submit([commandEncoder.finish()]);
237328

packages/typegpu/src/data/alignmentOf.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,17 @@ const knownAlignmentMap: Record<string, number> = {
2525
vec2h: 4,
2626
vec2i: 8,
2727
vec2u: 8,
28-
vec2b: 8,
28+
'vec2<bool>': 8,
2929
vec3f: 16,
3030
vec3h: 8,
3131
vec3i: 16,
3232
vec3u: 16,
33-
vec3b: 16,
33+
'vec3<bool>': 16,
3434
vec4f: 16,
3535
vec4h: 8,
3636
vec4i: 16,
3737
vec4u: 16,
38-
vec4b: 16,
38+
'vec4<bool>': 16,
3939
mat2x2f: 8,
4040
mat3x3f: 16,
4141
mat4x4f: 16,
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import type { AnyData } from './dataTypes.ts';
2+
import type { BaseData } from './wgslTypes.ts';
3+
import { getLayoutInfo } from './schemaMemoryLayout.ts';
4+
5+
export function getLongestContiguousPrefix(schema: BaseData): number {
6+
return getLayoutInfo(schema, 'longestContiguousPrefix');
7+
}
8+
9+
/**
10+
* Returns the size (in bytes) of the longest contiguous memory prefix of data represented by the `schema`.
11+
*/
12+
export function PUBLIC_getLongestContiguousPrefix(schema: AnyData): number {
13+
return getLongestContiguousPrefix(schema);
14+
}

packages/typegpu/src/data/index.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,11 @@ export {
211211
isUnstruct,
212212
} from './dataTypes.ts';
213213
export { PUBLIC_sizeOf as sizeOf } from './sizeOf.ts';
214+
export { PUBLIC_isContiguous as isContiguous } from './isContiguous.ts';
215+
export {
216+
PUBLIC_getLongestContiguousPrefix as getLongestContiguousPrefix,
217+
} from './getLongestContiguousPrefix.ts';
218+
export { getOffsetInfoAt } from './offsetUtils.ts';
214219
export { PUBLIC_alignmentOf as alignmentOf } from './alignmentOf.ts';
215220
export { builtin } from '../builtin.ts';
216221
export { deepEqual } from './deepEqual.ts';
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import type { AnyData } from './dataTypes.ts';
2+
import type { BaseData } from './wgslTypes.ts';
3+
import { getLayoutInfo } from './schemaMemoryLayout.ts';
4+
5+
export function isContiguous(schema: BaseData): boolean {
6+
return getLayoutInfo(schema, 'isContiguous');
7+
}
8+
9+
/**
10+
* Returns `true` if data represented by the `schema` doesn't have padding.
11+
*/
12+
export function PUBLIC_isContiguous(schema: AnyData): boolean {
13+
return isContiguous(schema);
14+
}

0 commit comments

Comments
 (0)