Skip to content

Commit f1279f3

Browse files
cieplypolariwoplaza
authored andcommitted
fix(@typegpu/three): Proper handling of varying nodes (#2125)
1 parent 9e7837e commit f1279f3

5 files changed

Lines changed: 230 additions & 26 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
<canvas></canvas>
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import * as THREE from 'three/webgpu';
2+
import * as TSL from 'three/tsl';
3+
import * as t3 from '@typegpu/three';
4+
import * as d from 'typegpu/data';
5+
import * as std from 'typegpu/std';
6+
7+
const canvas = document.querySelector('canvas') as HTMLCanvasElement;
8+
const renderer = new THREE.WebGPURenderer({ canvas, antialias: true });
9+
renderer.setSize(canvas.clientWidth, canvas.clientHeight, false);
10+
renderer.setPixelRatio(window.devicePixelRatio);
11+
await renderer.init();
12+
13+
const scene = new THREE.Scene();
14+
scene.background = new THREE.Color(0xe8abbf);
15+
16+
const camera = new THREE.PerspectiveCamera(
17+
45,
18+
canvas.clientWidth / canvas.clientHeight,
19+
0.1,
20+
100,
21+
);
22+
camera.position.set(0, 7, 7);
23+
camera.lookAt(0, 0, 0);
24+
25+
const vNormal = TSL.varying(TSL.vec3(), 'vNormal');
26+
const vNormalAccessor = t3.fromTSL(vNormal, d.vec3f);
27+
const posAccessor = t3.fromTSL(TSL.positionLocal, d.vec3f);
28+
29+
const updateNormal = (newNormal: d.v3f) => {
30+
'use gpu';
31+
vNormalAccessor.$.x = newNormal.x;
32+
vNormalAccessor.$.y = newNormal.y;
33+
vNormalAccessor.$.z = newNormal.z;
34+
};
35+
36+
const positionNode = t3.toTSL(() => {
37+
'use gpu';
38+
const frequency = d.f32(3.0);
39+
const amplitude = 0.5;
40+
const wave = std.sin(posAccessor.$.x * frequency + t3.time.$);
41+
42+
posAccessor.$.y += wave * amplitude;
43+
44+
const derivative = std.cos(posAccessor.$.x * frequency + t3.time.$) *
45+
amplitude * frequency;
46+
47+
const newNormalLocal = d.vec3f(-derivative, 1.0, 0);
48+
49+
updateNormal(newNormalLocal);
50+
51+
return d.vec3f(posAccessor.$);
52+
});
53+
54+
const transformedNormalAccessor = t3.fromTSL(
55+
TSL.transformNormalToView(vNormal),
56+
d.vec3f,
57+
);
58+
59+
const normalNode = t3.toTSL(() => {
60+
'use gpu';
61+
return std.normalize(transformedNormalAccessor.$);
62+
});
63+
64+
const material = new THREE.MeshStandardNodeMaterial({
65+
color: 0x9e0d3b,
66+
roughness: 0.1,
67+
metalness: 0.8,
68+
side: THREE.DoubleSide,
69+
});
70+
71+
material.positionNode = positionNode;
72+
material.normalNode = normalNode;
73+
74+
const geometry = new THREE.PlaneGeometry(4, 4, 100, 100);
75+
geometry.rotateX(-Math.PI / 2);
76+
77+
const mesh = new THREE.Mesh(geometry, material);
78+
scene.add(mesh);
79+
80+
const dirLight = new THREE.DirectionalLight(0xffffff, 3);
81+
dirLight.position.set(-7, 10, 0);
82+
scene.add(dirLight);
83+
scene.add(new THREE.AmbientLight(0x444444));
84+
85+
renderer.setAnimationLoop(() => {
86+
renderer.render(scene, camera);
87+
});
88+
89+
const resizeObserver = new ResizeObserver(() => {
90+
camera.aspect = canvas.clientWidth / canvas.clientHeight;
91+
camera.updateProjectionMatrix();
92+
renderer.setSize(canvas.clientWidth, canvas.clientHeight, false);
93+
});
94+
resizeObserver.observe(canvas);
95+
96+
export function onCleanup() {
97+
resizeObserver.disconnect();
98+
renderer.dispose();
99+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"title": "Three.js - Varyings",
3+
"category": "threejs",
4+
"tags": ["three.js"]
5+
}
130 KB
Loading

packages/typegpu-three/src/typegpu-node.ts

Lines changed: 125 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,25 @@ interface TgpuFnNodeData extends THREE.NodeData {
1717
};
1818
}
1919

20-
class StageData {
20+
abstract class StageData {
21+
declare readonly type: 'analyze' | 'generate';
2122
readonly stage: 'vertex' | 'fragment' | 'compute' | null;
22-
readonly names: WeakMap<object, string>;
2323
readonly namespace: Namespace;
24-
codeGeneratedThusFar: string;
2524

2625
constructor(stage: 'vertex' | 'fragment' | 'compute' | null) {
2726
this.stage = stage;
28-
this.names = new WeakMap();
2927
this.namespace = tgpu['~unstable'].namespace();
28+
}
29+
}
30+
31+
class GenerateStageData extends StageData {
32+
readonly names: WeakMap<object, string>;
33+
readonly type = 'generate';
34+
codeGeneratedThusFar: string;
35+
36+
constructor(stage: 'vertex' | 'fragment' | 'compute' | null) {
37+
super(stage);
38+
this.names = new WeakMap();
3039
this.codeGeneratedThusFar = '';
3140

3241
this.namespace.on('name', (event) => {
@@ -37,18 +46,43 @@ class StageData {
3746
}
3847
}
3948

49+
class AnalyzeStageData extends StageData {
50+
readonly type = 'analyze';
51+
}
52+
4053
class BuilderData {
41-
stageDataMap: Map<'vertex' | 'fragment' | 'compute' | null, StageData>;
54+
generateStageDataMap: Map<
55+
'vertex' | 'fragment' | 'compute' | null,
56+
GenerateStageData
57+
>;
58+
analyzeStageDataMap: Map<
59+
'vertex' | 'fragment' | 'compute' | null,
60+
AnalyzeStageData
61+
>;
4262

4363
constructor() {
44-
this.stageDataMap = new Map();
64+
this.generateStageDataMap = new Map();
65+
this.analyzeStageDataMap = new Map();
66+
}
67+
68+
getGenerateStageData(
69+
stage: 'vertex' | 'fragment' | 'compute' | null,
70+
): GenerateStageData {
71+
let stageData = this.generateStageDataMap.get(stage);
72+
if (!stageData) {
73+
stageData = new GenerateStageData(stage);
74+
this.generateStageDataMap.set(stage, stageData);
75+
}
76+
return stageData;
4577
}
4678

47-
getStageData(stage: 'vertex' | 'fragment' | 'compute' | null): StageData {
48-
let stageData = this.stageDataMap.get(stage);
79+
getAnalyzeStageData(
80+
stage: 'vertex' | 'fragment' | 'compute' | null,
81+
): AnalyzeStageData {
82+
let stageData = this.analyzeStageDataMap.get(stage);
4983
if (!stageData) {
50-
stageData = new StageData(stage);
51-
this.stageDataMap.set(stage, stageData);
84+
stageData = new AnalyzeStageData(stage);
85+
this.analyzeStageDataMap.set(stage, stageData);
5286
}
5387
return stageData;
5488
}
@@ -108,7 +142,7 @@ class TgpuFnNode<T> extends THREE.Node {
108142
builderDataMap.set(builder, builderData);
109143
}
110144

111-
const stageData = builderData.getStageData(builder.shaderStage);
145+
const stageData = builderData.getGenerateStageData(builder.shaderStage);
112146

113147
if (!nodeData.custom) {
114148
if (currentlyGeneratingFnNodeCtx !== undefined) {
@@ -163,6 +197,42 @@ class TgpuFnNode<T> extends THREE.Node {
163197
return nodeData.custom.nodeFunction;
164198
}
165199

200+
#analyzeFunction(builder: THREE.NodeBuilder) {
201+
let builderData = builderDataMap.get(builder);
202+
203+
if (!builderData) {
204+
builderData = new BuilderData();
205+
builderDataMap.set(builder, builderData);
206+
}
207+
208+
const stageData = builderData.getAnalyzeStageData(builder.shaderStage);
209+
210+
const ctx: TgpuFnNodeContext = {
211+
builder,
212+
stageData,
213+
dependencies: [],
214+
};
215+
currentlyGeneratingFnNodeCtx = ctx;
216+
try {
217+
tgpu.resolve({
218+
names: stageData.namespace,
219+
template: '___ID___ fnName',
220+
externals: { fnName: this.#impl },
221+
});
222+
} finally {
223+
currentlyGeneratingFnNodeCtx = undefined;
224+
}
225+
}
226+
227+
/**
228+
* Replicating Three.js `analyze` traversal.
229+
* Setting `needsInterpolation` flag to true in varying nodes
230+
*/
231+
analyze(builder: THREE.NodeBuilder, output?: THREE.Node | null) {
232+
super.analyze(builder, output);
233+
this.#analyzeFunction(builder); // making sure it will find all TSL accessors
234+
}
235+
166236
generate(
167237
builder: THREE.NodeBuilder,
168238
output: string | null | undefined,
@@ -171,15 +241,16 @@ class TgpuFnNode<T> extends THREE.Node {
171241

172242
const nodeData = builder.getDataFromNode(this) as TgpuFnNodeData;
173243
const builderData = builderDataMap.get(builder) as BuilderData;
174-
const stageData = builderData.getStageData(builder.shaderStage);
244+
const stageData = builderData.getGenerateStageData(builder.shaderStage);
175245

176246
// Building dependencies
177-
for (const dep of nodeData.custom.dependencies) {
247+
const uniqueDeps = [...new Set(nodeData.custom.dependencies)];
248+
for (const dep of uniqueDeps) {
178249
dep.node.build(builder);
179250
}
180251
nodeData.custom.priorCode.build(builder);
181252

182-
for (const dep of nodeData.custom.dependencies) {
253+
for (const dep of uniqueDeps) {
183254
if (!dep.var) {
184255
continue;
185256
}
@@ -190,10 +261,9 @@ class TgpuFnNode<T> extends THREE.Node {
190261
builder.addLineFlowCode(`${varName} = ${varValue};\n`, this);
191262
}
192263

193-
if (output === 'property') {
194-
return nodeData.custom.functionId;
195-
}
196-
return `${nodeData.custom.functionId}()`;
264+
return output === 'property'
265+
? nodeData.custom.functionId
266+
: `${nodeData.custom.functionId}()`;
197267
}
198268
}
199269

@@ -206,7 +276,7 @@ export function toTSL(
206276
export class TSLAccessor<T extends d.AnyWgslData, TNode extends THREE.Node> {
207277
readonly #dataType: T;
208278

209-
readonly var: TgpuVar<'private', T> | undefined;
279+
#var: TgpuVar<'private', T> | undefined;
210280
readonly node: THREE.TSL.NodeObject<TNode>;
211281

212282
constructor(
@@ -216,31 +286,54 @@ export class TSLAccessor<T extends d.AnyWgslData, TNode extends THREE.Node> {
216286
this.node = node;
217287
this.#dataType = dataType;
218288

219-
// node.isTextureNode - temporary workaround for textures
220289
if (
221-
// @ts-expect-error: The properties exist on the node
222-
(!node.isStorageBufferNode && !node.isUniformNode) || node.isTextureNode
290+
// @ts-expect-error: they are assigned at runtime
291+
(!node.isStorageBufferNode && !node.isUniformNode) ||
292+
// @ts-expect-error: it is assigned at runtime
293+
node.isTextureNode
223294
) {
224-
this.var = tgpu.privateVar(dataType);
295+
this.#var = tgpu.privateVar(dataType);
225296
}
226297
}
227298

299+
get var(): TgpuVar<'private', T> | undefined {
300+
return this.#var;
301+
}
302+
228303
get $(): d.InferGPU<T> {
229304
const ctx = currentlyGeneratingFnNodeCtx;
230305

231306
if (!ctx) {
232307
throw new Error('Can only access TSL nodes on the GPU.');
233308
}
234309

310+
if (ctx.stageData.type === 'analyze') {
311+
this.node.traverse((node: THREE.Node) => {
312+
node.analyze(ctx.builder);
313+
});
314+
// dummy return, only for types to match
315+
return tgpu['~unstable'].rawCodeSnippet('', this.#dataType, 'runtime').$;
316+
}
317+
235318
// biome-ignore lint/suspicious/noExplicitAny: smh
236319
ctx.dependencies.push(this as any);
237320

321+
const builtNode = this.node.build(ctx.builder) as string;
322+
323+
// @ts-expect-error: it is assigned at runtime
324+
const trueVaryingNode = this.node.isVaryingNode &&
325+
builtNode.includes('varyings.');
326+
327+
if (trueVaryingNode) {
328+
this.#var = undefined; // cannot be checked earlier, ThreeJS is lazy
329+
}
330+
238331
if (this.var) {
239332
return this.var.$;
240333
}
241334

242335
return tgpu['~unstable'].rawCodeSnippet(
243-
this.node.build(ctx.builder) as string,
336+
builtNode,
244337
this.#dataType,
245338
).$;
246339
}
@@ -295,8 +388,14 @@ export const fromTSL = tgpu['~unstable'].comptime<
295388
if (!sharedBuilder) {
296389
sharedBuilder = new WGSLNodeBuilder();
297390
}
298-
const nodeType = node.getNodeType(sharedBuilder);
299-
391+
let nodeType: string | null = null;
392+
try { // sometimes it needs information (overrideNodes) from compilation context which is not present
393+
nodeType = node.getNodeType(sharedBuilder);
394+
} catch (e) {
395+
console.warn(
396+
`fromTSL: failed to infer node type via getNodeType; skipping type comparison.`,
397+
);
398+
}
300399
if (nodeType) {
301400
const wgslTypeFromTSL = sharedBuilder.getType(nodeType);
302401
if (wgslTypeFromTSL !== wgslTypeFromTgpu) {

0 commit comments

Comments
 (0)