Skip to content

Commit 7950868

Browse files
authored
fix: Atomic and decorated types for SoA and write fast path (#2329)
1 parent 712fbd5 commit 7950868

3 files changed

Lines changed: 258 additions & 70 deletions

File tree

packages/typegpu/src/common/writeSoA.ts

Lines changed: 49 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
import { invariant } from '../errors.ts';
22
import { roundUp } from '../mathUtils.ts';
3+
import type { Undecorate } from '../data/dataTypes.ts';
34
import { alignmentOf } from '../data/alignmentOf.ts';
5+
import { undecorate } from '../data/dataTypes.ts';
46
import { offsetsForProps } from '../data/offsets.ts';
57
import { sizeOf } from '../data/sizeOf.ts';
68
import type { BaseData, TypedArrayFor, WgslArray, WgslStruct } from '../data/wgslTypes.ts';
7-
import { isMat, isMat2x2f, isMat3x3f, isWgslArray } from '../data/wgslTypes.ts';
9+
import { isAtomic, isMat, isMat2x2f, isMat3x3f, isWgslArray } from '../data/wgslTypes.ts';
810
import type { BufferWriteOptions, TgpuBuffer } from '../core/buffer/buffer.ts';
911
import type { Prettify } from '../shared/utilityTypes.ts';
1012

11-
type UnwrapWgslArray<T> = T extends WgslArray<infer U> ? UnwrapWgslArray<U> : T;
12-
type PackedSoAInputFor<T> = TypedArrayFor<UnwrapWgslArray<T>>;
13+
type PackedScalarFor<T> =
14+
Undecorate<T> extends WgslArray<infer TElement> ? PackedScalarFor<TElement> : Undecorate<T>;
15+
16+
type PackedSoAInputFor<T> = TypedArrayFor<PackedScalarFor<T>>;
1317

1418
type SoAFieldsFor<T extends Record<string, BaseData>> = {
1519
[K in keyof T as [PackedSoAInputFor<T[K]>] extends [never] ? never : K]: PackedSoAInputFor<T[K]>;
@@ -19,74 +23,53 @@ type SoAInputFor<T extends Record<string, BaseData>> = [keyof T] extends [keyof
1923
? Prettify<SoAFieldsFor<T>>
2024
: never;
2125

22-
function getPackedMatrixLayout(schema: BaseData) {
23-
if (!isMat(schema)) {
24-
return undefined;
25-
}
26-
27-
const dim = isMat3x3f(schema) ? 3 : isMat2x2f(schema) ? 2 : 4;
28-
const packedColumnSize = dim * 4;
26+
function packedSchemaOf(schema: BaseData): BaseData {
27+
const unpackedSchema = undecorate(schema);
28+
return isAtomic(unpackedSchema) ? unpackedSchema.inner : unpackedSchema;
29+
}
2930

30-
return {
31-
dim,
32-
packedColumnSize,
33-
packedSize: dim * packedColumnSize,
34-
} as const;
31+
function packedMatrixDimOf(schema: BaseData): 2 | 3 | 4 | undefined {
32+
return isMat3x3f(schema) ? 3 : isMat2x2f(schema) ? 2 : isMat(schema) ? 4 : undefined;
3533
}
3634

3735
function packedSizeOf(schema: BaseData): number {
38-
const matrixLayout = getPackedMatrixLayout(schema);
39-
if (matrixLayout) {
40-
return matrixLayout.packedSize;
36+
const packedSchema = packedSchemaOf(schema);
37+
const matrixDim = packedMatrixDimOf(packedSchema);
38+
if (matrixDim) {
39+
return matrixDim * matrixDim * 4;
4140
}
42-
43-
if (isWgslArray(schema)) {
44-
return schema.elementCount * packedSizeOf(schema.elementType);
41+
if (isWgslArray(packedSchema)) {
42+
return packedSchema.elementCount * packedSizeOf(packedSchema.elementType);
4543
}
46-
47-
return sizeOf(schema);
44+
return sizeOf(packedSchema);
4845
}
4946

50-
function inferSoAElementCount(
47+
function computeSoAByteLength(
5148
arraySchema: WgslArray,
5249
soaData: Record<string, ArrayBufferView>,
5350
): number | undefined {
5451
const structSchema = arraySchema.elementType as WgslStruct;
5552
let inferredCount: number | undefined;
5653

57-
for (const key in soaData) {
54+
for (const key in structSchema.propTypes) {
5855
const srcArray = soaData[key];
5956
const fieldSchema = structSchema.propTypes[key];
6057
if (srcArray === undefined || fieldSchema === undefined) {
6158
continue;
6259
}
63-
64-
const fieldPackedSize = packedSizeOf(fieldSchema);
65-
if (fieldPackedSize === 0) {
60+
const packedFieldSize = packedSizeOf(fieldSchema);
61+
if (packedFieldSize === 0) {
6662
continue;
6763
}
68-
69-
const fieldElementCount = Math.floor(srcArray.byteLength / fieldPackedSize);
64+
const fieldElementCount = Math.floor(srcArray.byteLength / packedFieldSize);
7065
inferredCount =
7166
inferredCount === undefined ? fieldElementCount : Math.min(inferredCount, fieldElementCount);
7267
}
73-
74-
return inferredCount;
75-
}
76-
77-
function computeSoAByteLength(
78-
arraySchema: WgslArray,
79-
soaData: Record<string, ArrayBufferView>,
80-
): number | undefined {
81-
const elementCount = inferSoAElementCount(arraySchema, soaData);
82-
if (elementCount === undefined) {
68+
if (inferredCount === undefined) {
8369
return undefined;
8470
}
85-
const elementStride = roundUp(
86-
sizeOf(arraySchema.elementType),
87-
alignmentOf(arraySchema.elementType),
88-
);
89-
return elementCount * elementStride;
71+
const elementStride = roundUp(sizeOf(structSchema), alignmentOf(structSchema));
72+
return inferredCount * elementStride;
9073
}
9174

9275
function writePackedValue(
@@ -96,41 +79,42 @@ function writePackedValue(
9679
dstOffset: number,
9780
srcOffset: number,
9881
): void {
99-
const matrixLayout = getPackedMatrixLayout(schema);
100-
if (matrixLayout) {
101-
const gpuColumnStride = roundUp(matrixLayout.packedColumnSize, alignmentOf(schema));
102-
103-
for (let col = 0; col < matrixLayout.dim; col++) {
82+
const unpackedSchema = undecorate(schema);
83+
const packedSchema = isAtomic(unpackedSchema) ? unpackedSchema.inner : unpackedSchema;
84+
const matrixDim = packedMatrixDimOf(packedSchema);
85+
if (matrixDim) {
86+
const packedColumnSize = matrixDim * 4;
87+
const gpuColumnStride = roundUp(packedColumnSize, alignmentOf(schema));
88+
for (let col = 0; col < matrixDim; col++) {
10489
target.set(
10590
srcBytes.subarray(
106-
srcOffset + col * matrixLayout.packedColumnSize,
107-
srcOffset + col * matrixLayout.packedColumnSize + matrixLayout.packedColumnSize,
91+
srcOffset + col * packedColumnSize,
92+
srcOffset + col * packedColumnSize + packedColumnSize,
10893
),
10994
dstOffset + col * gpuColumnStride,
11095
);
11196
}
112-
11397
return;
11498
}
115-
116-
if (isWgslArray(schema)) {
117-
const packedElementSize = packedSizeOf(schema.elementType);
118-
const gpuElementStride = roundUp(sizeOf(schema.elementType), alignmentOf(schema.elementType));
119-
120-
for (let i = 0; i < schema.elementCount; i++) {
99+
if (isWgslArray(unpackedSchema)) {
100+
const packedElementSize = packedSizeOf(unpackedSchema.elementType);
101+
const gpuElementStride = roundUp(
102+
sizeOf(unpackedSchema.elementType),
103+
alignmentOf(unpackedSchema.elementType),
104+
);
105+
106+
for (let i = 0; i < unpackedSchema.elementCount; i++) {
121107
writePackedValue(
122108
target,
123-
schema.elementType,
109+
unpackedSchema.elementType,
124110
srcBytes,
125111
dstOffset + i * gpuElementStride,
126112
srcOffset + i * packedElementSize,
127113
);
128114
}
129-
130115
return;
131116
}
132-
133-
target.set(srcBytes.subarray(srcOffset, srcOffset + sizeOf(schema)), dstOffset);
117+
target.set(srcBytes.subarray(srcOffset, srcOffset + sizeOf(packedSchema)), dstOffset);
134118
}
135119

136120
function scatterSoA(
@@ -141,7 +125,6 @@ function scatterSoA(
141125
endOffset: number,
142126
): void {
143127
const structSchema = arraySchema.elementType as WgslStruct;
144-
const offsets = offsetsForProps(structSchema);
145128
const elementStride = roundUp(sizeOf(structSchema), alignmentOf(structSchema));
146129
invariant(
147130
startOffset % elementStride === 0,
@@ -150,6 +133,7 @@ function scatterSoA(
150133
const startElement = Math.floor(startOffset / elementStride);
151134
const endElement = Math.min(arraySchema.elementCount, Math.ceil(endOffset / elementStride));
152135
const elementCount = Math.max(0, endElement - startElement);
136+
const offsets = offsetsForProps(structSchema);
153137

154138
for (const key in structSchema.propTypes) {
155139
const fieldSchema = structSchema.propTypes[key];
@@ -158,12 +142,10 @@ function scatterSoA(
158142
}
159143
const srcArray = soaData[key];
160144
invariant(srcArray !== undefined, `Missing SoA data for field '${key}'`);
161-
162145
const fieldOffset = offsets[key]?.offset;
163146
invariant(fieldOffset !== undefined, `Field ${key} not found in struct schema`);
164-
const srcBytes = new Uint8Array(srcArray.buffer, srcArray.byteOffset, srcArray.byteLength);
165-
166147
const packedFieldSize = packedSizeOf(fieldSchema);
148+
const srcBytes = new Uint8Array(srcArray.buffer, srcArray.byteOffset, srcArray.byteLength);
167149
for (let i = 0; i < elementCount; i++) {
168150
writePackedValue(
169151
target,

packages/typegpu/src/data/wgslTypes.ts

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,15 @@ export type TypedArrayFor<T> = T extends F32 | Vec2f | Vec3f | Vec4f | Mat2x2f |
6262
? Float32Array
6363
: T extends F16 | Vec2h | Vec3h | Vec4h
6464
? Float16Array
65-
: T extends I32 | Vec2i | Vec3i | Vec4i
65+
: T extends I32 | Vec2i | Vec3i | Vec4i | Atomic<I32>
6666
? Int32Array
67-
: T extends U32 | Vec2u | Vec3u | Vec4u
67+
: T extends U32 | Vec2u | Vec3u | Vec4u | Atomic<U32>
6868
? Uint32Array
6969
: T extends U16
7070
? Uint16Array
71-
: never;
71+
: T extends Decorated<infer TBase>
72+
? TypedArrayFor<TBase>
73+
: never;
7274

7375
/**
7476
* Vector infix notation.

0 commit comments

Comments
 (0)