@@ -17,16 +17,25 @@ interface TgpuFnNodeData extends THREE.NodeData {
1717 } ;
1818}
1919
20- class StageData {
20+ abstract class StageData {
21+ declare readonly type : 'analyze' | 'generate' ;
2122 readonly stage : 'vertex' | 'fragment' | 'compute' | null ;
22- readonly names : WeakMap < object , string > ;
2323 readonly namespace : Namespace ;
24- codeGeneratedThusFar : string ;
2524
2625 constructor ( stage : 'vertex' | 'fragment' | 'compute' | null ) {
2726 this . stage = stage ;
28- this . names = new WeakMap ( ) ;
2927 this . namespace = tgpu [ '~unstable' ] . namespace ( ) ;
28+ }
29+ }
30+
31+ class GenerateStageData extends StageData {
32+ readonly names : WeakMap < object , string > ;
33+ readonly type = 'generate' ;
34+ codeGeneratedThusFar : string ;
35+
36+ constructor ( stage : 'vertex' | 'fragment' | 'compute' | null ) {
37+ super ( stage ) ;
38+ this . names = new WeakMap ( ) ;
3039 this . codeGeneratedThusFar = '' ;
3140
3241 this . namespace . on ( 'name' , ( event ) => {
@@ -37,18 +46,43 @@ class StageData {
3746 }
3847}
3948
49+ class AnalyzeStageData extends StageData {
50+ readonly type = 'analyze' ;
51+ }
52+
4053class BuilderData {
41- stageDataMap : Map < 'vertex' | 'fragment' | 'compute' | null , StageData > ;
54+ generateStageDataMap : Map <
55+ 'vertex' | 'fragment' | 'compute' | null ,
56+ GenerateStageData
57+ > ;
58+ analyzeStageDataMap : Map <
59+ 'vertex' | 'fragment' | 'compute' | null ,
60+ AnalyzeStageData
61+ > ;
4262
4363 constructor ( ) {
44- this . stageDataMap = new Map ( ) ;
64+ this . generateStageDataMap = new Map ( ) ;
65+ this . analyzeStageDataMap = new Map ( ) ;
66+ }
67+
68+ getGenerateStageData (
69+ stage : 'vertex' | 'fragment' | 'compute' | null ,
70+ ) : GenerateStageData {
71+ let stageData = this . generateStageDataMap . get ( stage ) ;
72+ if ( ! stageData ) {
73+ stageData = new GenerateStageData ( stage ) ;
74+ this . generateStageDataMap . set ( stage , stageData ) ;
75+ }
76+ return stageData ;
4577 }
4678
47- getStageData ( stage : 'vertex' | 'fragment' | 'compute' | null ) : StageData {
48- let stageData = this . stageDataMap . get ( stage ) ;
79+ getAnalyzeStageData (
80+ stage : 'vertex' | 'fragment' | 'compute' | null ,
81+ ) : AnalyzeStageData {
82+ let stageData = this . analyzeStageDataMap . get ( stage ) ;
4983 if ( ! stageData ) {
50- stageData = new StageData ( stage ) ;
51- this . stageDataMap . set ( stage , stageData ) ;
84+ stageData = new AnalyzeStageData ( stage ) ;
85+ this . analyzeStageDataMap . set ( stage , stageData ) ;
5286 }
5387 return stageData ;
5488 }
@@ -108,7 +142,7 @@ class TgpuFnNode<T> extends THREE.Node {
108142 builderDataMap . set ( builder , builderData ) ;
109143 }
110144
111- const stageData = builderData . getStageData ( builder . shaderStage ) ;
145+ const stageData = builderData . getGenerateStageData ( builder . shaderStage ) ;
112146
113147 if ( ! nodeData . custom ) {
114148 if ( currentlyGeneratingFnNodeCtx !== undefined ) {
@@ -163,6 +197,42 @@ class TgpuFnNode<T> extends THREE.Node {
163197 return nodeData . custom . nodeFunction ;
164198 }
165199
200+ #analyzeFunction( builder : THREE . NodeBuilder ) {
201+ let builderData = builderDataMap . get ( builder ) ;
202+
203+ if ( ! builderData ) {
204+ builderData = new BuilderData ( ) ;
205+ builderDataMap . set ( builder , builderData ) ;
206+ }
207+
208+ const stageData = builderData . getAnalyzeStageData ( builder . shaderStage ) ;
209+
210+ const ctx : TgpuFnNodeContext = {
211+ builder,
212+ stageData,
213+ dependencies : [ ] ,
214+ } ;
215+ currentlyGeneratingFnNodeCtx = ctx ;
216+ try {
217+ tgpu . resolve ( {
218+ names : stageData . namespace ,
219+ template : '___ID___ fnName' ,
220+ externals : { fnName : this . #impl } ,
221+ } ) ;
222+ } finally {
223+ currentlyGeneratingFnNodeCtx = undefined ;
224+ }
225+ }
226+
227+ /**
228+ * Replicating Three.js `analyze` traversal.
229+ * Setting `needsInterpolation` flag to true in varying nodes
230+ */
231+ analyze ( builder : THREE . NodeBuilder , output ?: THREE . Node | null ) {
232+ super . analyze ( builder , output ) ;
233+ this . #analyzeFunction( builder ) ; // making sure it will find all TSL accessors
234+ }
235+
166236 generate (
167237 builder : THREE . NodeBuilder ,
168238 output : string | null | undefined ,
@@ -171,15 +241,16 @@ class TgpuFnNode<T> extends THREE.Node {
171241
172242 const nodeData = builder . getDataFromNode ( this ) as TgpuFnNodeData ;
173243 const builderData = builderDataMap . get ( builder ) as BuilderData ;
174- const stageData = builderData . getStageData ( builder . shaderStage ) ;
244+ const stageData = builderData . getGenerateStageData ( builder . shaderStage ) ;
175245
176246 // Building dependencies
177- for ( const dep of nodeData . custom . dependencies ) {
247+ const uniqueDeps = [ ...new Set ( nodeData . custom . dependencies ) ] ;
248+ for ( const dep of uniqueDeps ) {
178249 dep . node . build ( builder ) ;
179250 }
180251 nodeData . custom . priorCode . build ( builder ) ;
181252
182- for ( const dep of nodeData . custom . dependencies ) {
253+ for ( const dep of uniqueDeps ) {
183254 if ( ! dep . var ) {
184255 continue ;
185256 }
@@ -190,10 +261,9 @@ class TgpuFnNode<T> extends THREE.Node {
190261 builder . addLineFlowCode ( `${ varName } = ${ varValue } ;\n` , this ) ;
191262 }
192263
193- if ( output === 'property' ) {
194- return nodeData . custom . functionId ;
195- }
196- return `${ nodeData . custom . functionId } ()` ;
264+ return output === 'property'
265+ ? nodeData . custom . functionId
266+ : `${ nodeData . custom . functionId } ()` ;
197267 }
198268}
199269
@@ -206,7 +276,7 @@ export function toTSL(
206276export class TSLAccessor < T extends d . AnyWgslData , TNode extends THREE . Node > {
207277 readonly #dataType: T ;
208278
209- readonly var : TgpuVar < 'private' , T > | undefined ;
279+ # var: TgpuVar < 'private' , T > | undefined ;
210280 readonly node : THREE . TSL . NodeObject < TNode > ;
211281
212282 constructor (
@@ -216,31 +286,54 @@ export class TSLAccessor<T extends d.AnyWgslData, TNode extends THREE.Node> {
216286 this . node = node ;
217287 this . #dataType = dataType ;
218288
219- // node.isTextureNode - temporary workaround for textures
220289 if (
221- // @ts -expect-error: The properties exist on the node
222- ( ! node . isStorageBufferNode && ! node . isUniformNode ) || node . isTextureNode
290+ // @ts -expect-error: they are assigned at runtime
291+ ( ! node . isStorageBufferNode && ! node . isUniformNode ) ||
292+ // @ts -expect-error: it is assigned at runtime
293+ node . isTextureNode
223294 ) {
224- this . var = tgpu . privateVar ( dataType ) ;
295+ this . # var = tgpu . privateVar ( dataType ) ;
225296 }
226297 }
227298
299+ get var ( ) : TgpuVar < 'private' , T > | undefined {
300+ return this . #var;
301+ }
302+
228303 get $ ( ) : d . InferGPU < T > {
229304 const ctx = currentlyGeneratingFnNodeCtx ;
230305
231306 if ( ! ctx ) {
232307 throw new Error ( 'Can only access TSL nodes on the GPU.' ) ;
233308 }
234309
310+ if ( ctx . stageData . type === 'analyze' ) {
311+ this . node . traverse ( ( node : THREE . Node ) => {
312+ node . analyze ( ctx . builder ) ;
313+ } ) ;
314+ // dummy return, only for types to match
315+ return tgpu [ '~unstable' ] . rawCodeSnippet ( '' , this . #dataType, 'runtime' ) . $ ;
316+ }
317+
235318 // biome-ignore lint/suspicious/noExplicitAny: smh
236319 ctx . dependencies . push ( this as any ) ;
237320
321+ const builtNode = this . node . build ( ctx . builder ) as string ;
322+
323+ // @ts -expect-error: it is assigned at runtime
324+ const trueVaryingNode = this . node . isVaryingNode &&
325+ builtNode . includes ( 'varyings.' ) ;
326+
327+ if ( trueVaryingNode ) {
328+ this . #var = undefined ; // cannot be checked earlier, ThreeJS is lazy
329+ }
330+
238331 if ( this . var ) {
239332 return this . var . $ ;
240333 }
241334
242335 return tgpu [ '~unstable' ] . rawCodeSnippet (
243- this . node . build ( ctx . builder ) as string ,
336+ builtNode ,
244337 this . #dataType,
245338 ) . $ ;
246339 }
@@ -295,8 +388,14 @@ export const fromTSL = tgpu['~unstable'].comptime<
295388 if ( ! sharedBuilder ) {
296389 sharedBuilder = new WGSLNodeBuilder ( ) ;
297390 }
298- const nodeType = node . getNodeType ( sharedBuilder ) ;
299-
391+ let nodeType : string | null = null ;
392+ try { // sometimes it needs information (overrideNodes) from compilation context which is not present
393+ nodeType = node . getNodeType ( sharedBuilder ) ;
394+ } catch ( e ) {
395+ console . warn (
396+ `fromTSL: failed to infer node type via getNodeType; skipping type comparison.` ,
397+ ) ;
398+ }
300399 if ( nodeType ) {
301400 const wgslTypeFromTSL = sharedBuilder . getType ( nodeType ) ;
302401 if ( wgslTypeFromTSL !== wgslTypeFromTgpu ) {
0 commit comments