11import { invariant } from '../errors.ts' ;
22import { roundUp } from '../mathUtils.ts' ;
3+ import type { Undecorate } from '../data/dataTypes.ts' ;
34import { alignmentOf } from '../data/alignmentOf.ts' ;
5+ import { undecorate } from '../data/dataTypes.ts' ;
46import { offsetsForProps } from '../data/offsets.ts' ;
57import { sizeOf } from '../data/sizeOf.ts' ;
68import type { BaseData , TypedArrayFor , WgslArray , WgslStruct } from '../data/wgslTypes.ts' ;
7- import { isMat , isMat2x2f , isMat3x3f , isWgslArray } from '../data/wgslTypes.ts' ;
9+ import { isAtomic , isMat , isMat2x2f , isMat3x3f , isWgslArray } from '../data/wgslTypes.ts' ;
810import type { BufferWriteOptions , TgpuBuffer } from '../core/buffer/buffer.ts' ;
911import type { Prettify } from '../shared/utilityTypes.ts' ;
1012
11- type UnwrapWgslArray < T > = T extends WgslArray < infer U > ? UnwrapWgslArray < U > : T ;
12- type PackedSoAInputFor < T > = TypedArrayFor < UnwrapWgslArray < T > > ;
13+ type PackedScalarFor < T > =
14+ Undecorate < T > extends WgslArray < infer TElement > ? PackedScalarFor < TElement > : Undecorate < T > ;
15+
16+ type PackedSoAInputFor < T > = TypedArrayFor < PackedScalarFor < T > > ;
1317
1418type SoAFieldsFor < T extends Record < string , BaseData > > = {
1519 [ K in keyof T as [ PackedSoAInputFor < T [ K ] > ] extends [ never ] ? never : K ] : PackedSoAInputFor < T [ K ] > ;
@@ -19,74 +23,53 @@ type SoAInputFor<T extends Record<string, BaseData>> = [keyof T] extends [keyof
1923 ? Prettify < SoAFieldsFor < T > >
2024 : never ;
2125
22- function getPackedMatrixLayout ( schema : BaseData ) {
23- if ( ! isMat ( schema ) ) {
24- return undefined ;
25- }
26-
27- const dim = isMat3x3f ( schema ) ? 3 : isMat2x2f ( schema ) ? 2 : 4 ;
28- const packedColumnSize = dim * 4 ;
26+ function packedSchemaOf ( schema : BaseData ) : BaseData {
27+ const unpackedSchema = undecorate ( schema ) ;
28+ return isAtomic ( unpackedSchema ) ? unpackedSchema . inner : unpackedSchema ;
29+ }
2930
30- return {
31- dim,
32- packedColumnSize,
33- packedSize : dim * packedColumnSize ,
34- } as const ;
31+ function packedMatrixDimOf ( schema : BaseData ) : 2 | 3 | 4 | undefined {
32+ return isMat3x3f ( schema ) ? 3 : isMat2x2f ( schema ) ? 2 : isMat ( schema ) ? 4 : undefined ;
3533}
3634
3735function packedSizeOf ( schema : BaseData ) : number {
38- const matrixLayout = getPackedMatrixLayout ( schema ) ;
39- if ( matrixLayout ) {
40- return matrixLayout . packedSize ;
36+ const packedSchema = packedSchemaOf ( schema ) ;
37+ const matrixDim = packedMatrixDimOf ( packedSchema ) ;
38+ if ( matrixDim ) {
39+ return matrixDim * matrixDim * 4 ;
4140 }
42-
43- if ( isWgslArray ( schema ) ) {
44- return schema . elementCount * packedSizeOf ( schema . elementType ) ;
41+ if ( isWgslArray ( packedSchema ) ) {
42+ return packedSchema . elementCount * packedSizeOf ( packedSchema . elementType ) ;
4543 }
46-
47- return sizeOf ( schema ) ;
44+ return sizeOf ( packedSchema ) ;
4845}
4946
50- function inferSoAElementCount (
47+ function computeSoAByteLength (
5148 arraySchema : WgslArray ,
5249 soaData : Record < string , ArrayBufferView > ,
5350) : number | undefined {
5451 const structSchema = arraySchema . elementType as WgslStruct ;
5552 let inferredCount : number | undefined ;
5653
57- for ( const key in soaData ) {
54+ for ( const key in structSchema . propTypes ) {
5855 const srcArray = soaData [ key ] ;
5956 const fieldSchema = structSchema . propTypes [ key ] ;
6057 if ( srcArray === undefined || fieldSchema === undefined ) {
6158 continue ;
6259 }
63-
64- const fieldPackedSize = packedSizeOf ( fieldSchema ) ;
65- if ( fieldPackedSize === 0 ) {
60+ const packedFieldSize = packedSizeOf ( fieldSchema ) ;
61+ if ( packedFieldSize === 0 ) {
6662 continue ;
6763 }
68-
69- const fieldElementCount = Math . floor ( srcArray . byteLength / fieldPackedSize ) ;
64+ const fieldElementCount = Math . floor ( srcArray . byteLength / packedFieldSize ) ;
7065 inferredCount =
7166 inferredCount === undefined ? fieldElementCount : Math . min ( inferredCount , fieldElementCount ) ;
7267 }
73-
74- return inferredCount ;
75- }
76-
77- function computeSoAByteLength (
78- arraySchema : WgslArray ,
79- soaData : Record < string , ArrayBufferView > ,
80- ) : number | undefined {
81- const elementCount = inferSoAElementCount ( arraySchema , soaData ) ;
82- if ( elementCount === undefined ) {
68+ if ( inferredCount === undefined ) {
8369 return undefined ;
8470 }
85- const elementStride = roundUp (
86- sizeOf ( arraySchema . elementType ) ,
87- alignmentOf ( arraySchema . elementType ) ,
88- ) ;
89- return elementCount * elementStride ;
71+ const elementStride = roundUp ( sizeOf ( structSchema ) , alignmentOf ( structSchema ) ) ;
72+ return inferredCount * elementStride ;
9073}
9174
9275function writePackedValue (
@@ -96,41 +79,42 @@ function writePackedValue(
9679 dstOffset : number ,
9780 srcOffset : number ,
9881) : void {
99- const matrixLayout = getPackedMatrixLayout ( schema ) ;
100- if ( matrixLayout ) {
101- const gpuColumnStride = roundUp ( matrixLayout . packedColumnSize , alignmentOf ( schema ) ) ;
102-
103- for ( let col = 0 ; col < matrixLayout . dim ; col ++ ) {
82+ const unpackedSchema = undecorate ( schema ) ;
83+ const packedSchema = isAtomic ( unpackedSchema ) ? unpackedSchema . inner : unpackedSchema ;
84+ const matrixDim = packedMatrixDimOf ( packedSchema ) ;
85+ if ( matrixDim ) {
86+ const packedColumnSize = matrixDim * 4 ;
87+ const gpuColumnStride = roundUp ( packedColumnSize , alignmentOf ( schema ) ) ;
88+ for ( let col = 0 ; col < matrixDim ; col ++ ) {
10489 target . set (
10590 srcBytes . subarray (
106- srcOffset + col * matrixLayout . packedColumnSize ,
107- srcOffset + col * matrixLayout . packedColumnSize + matrixLayout . packedColumnSize ,
91+ srcOffset + col * packedColumnSize ,
92+ srcOffset + col * packedColumnSize + packedColumnSize ,
10893 ) ,
10994 dstOffset + col * gpuColumnStride ,
11095 ) ;
11196 }
112-
11397 return ;
11498 }
115-
116- if ( isWgslArray ( schema ) ) {
117- const packedElementSize = packedSizeOf ( schema . elementType ) ;
118- const gpuElementStride = roundUp ( sizeOf ( schema . elementType ) , alignmentOf ( schema . elementType ) ) ;
119-
120- for ( let i = 0 ; i < schema . elementCount ; i ++ ) {
99+ if ( isWgslArray ( unpackedSchema ) ) {
100+ const packedElementSize = packedSizeOf ( unpackedSchema . elementType ) ;
101+ const gpuElementStride = roundUp (
102+ sizeOf ( unpackedSchema . elementType ) ,
103+ alignmentOf ( unpackedSchema . elementType ) ,
104+ ) ;
105+
106+ for ( let i = 0 ; i < unpackedSchema . elementCount ; i ++ ) {
121107 writePackedValue (
122108 target ,
123- schema . elementType ,
109+ unpackedSchema . elementType ,
124110 srcBytes ,
125111 dstOffset + i * gpuElementStride ,
126112 srcOffset + i * packedElementSize ,
127113 ) ;
128114 }
129-
130115 return ;
131116 }
132-
133- target . set ( srcBytes . subarray ( srcOffset , srcOffset + sizeOf ( schema ) ) , dstOffset ) ;
117+ target . set ( srcBytes . subarray ( srcOffset , srcOffset + sizeOf ( packedSchema ) ) , dstOffset ) ;
134118}
135119
136120function scatterSoA (
@@ -141,7 +125,6 @@ function scatterSoA(
141125 endOffset : number ,
142126) : void {
143127 const structSchema = arraySchema . elementType as WgslStruct ;
144- const offsets = offsetsForProps ( structSchema ) ;
145128 const elementStride = roundUp ( sizeOf ( structSchema ) , alignmentOf ( structSchema ) ) ;
146129 invariant (
147130 startOffset % elementStride === 0 ,
@@ -150,6 +133,7 @@ function scatterSoA(
150133 const startElement = Math . floor ( startOffset / elementStride ) ;
151134 const endElement = Math . min ( arraySchema . elementCount , Math . ceil ( endOffset / elementStride ) ) ;
152135 const elementCount = Math . max ( 0 , endElement - startElement ) ;
136+ const offsets = offsetsForProps ( structSchema ) ;
153137
154138 for ( const key in structSchema . propTypes ) {
155139 const fieldSchema = structSchema . propTypes [ key ] ;
@@ -158,12 +142,10 @@ function scatterSoA(
158142 }
159143 const srcArray = soaData [ key ] ;
160144 invariant ( srcArray !== undefined , `Missing SoA data for field '${ key } '` ) ;
161-
162145 const fieldOffset = offsets [ key ] ?. offset ;
163146 invariant ( fieldOffset !== undefined , `Field ${ key } not found in struct schema` ) ;
164- const srcBytes = new Uint8Array ( srcArray . buffer , srcArray . byteOffset , srcArray . byteLength ) ;
165-
166147 const packedFieldSize = packedSizeOf ( fieldSchema ) ;
148+ const srcBytes = new Uint8Array ( srcArray . buffer , srcArray . byteOffset , srcArray . byteLength ) ;
167149 for ( let i = 0 ; i < elementCount ; i ++ ) {
168150 writePackedValue (
169151 target ,
0 commit comments