Skip to content

Commit a08371c

Browse files
authored
docs: Optimize 3d slime mold example (#1900)
1 parent 0d092d7 commit a08371c

2 files changed

Lines changed: 223 additions & 167 deletions

File tree

apps/typegpu-docs/src/examples/simulation/slime-mold-3d/index.ts

Lines changed: 85 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@ import { randf } from '@typegpu/noise';
66
import * as m from 'wgpu-matrix';
77

88
const root = await tgpu.init({
9-
device: {
10-
optionalFeatures: ['float32-filterable'],
11-
},
9+
device: { optionalFeatures: ['float32-filterable'] },
1210
});
1311
const canFilter = root.enabledFeatures.has('float32-filterable');
1412
const device = root.device;
@@ -32,8 +30,8 @@ const CAMERA_FOV_DEGREES = 60;
3230
const CAMERA_DISTANCE_MULTIPLIER = 1.5;
3331
const CAMERA_INITIAL_ANGLE = Math.PI / 4;
3432

35-
const RAYMARCH_STEPS = 128;
36-
const DENSITY_MULTIPLIER = 0.05;
33+
const RAYMARCH_STEPS = 48;
34+
const DENSITY_MULTIPLIER = 0.1;
3735

3836
const RANDOM_DIRECTION_WEIGHT = 0.3;
3937
const CENTER_BIAS_WEIGHT = 0.7;
@@ -105,15 +103,19 @@ const Params = d.struct({
105103
evaporationRate: d.f32,
106104
});
107105

108-
const agentsData = root.createMutable(d.arrayOf(Agent, NUM_AGENTS));
106+
const agentsDataBuffers = [0, 1].map(() =>
107+
root.createBuffer(d.arrayOf(Agent, NUM_AGENTS)).$usage('storage')
108+
);
109109

110+
const mutableAgentsDataBuffers = agentsDataBuffers.map((b) => b.as('mutable'));
110111
root['~unstable'].createGuardedComputePipeline((x) => {
111112
'use gpu';
112113
randf.seed(x / NUM_AGENTS);
113114
const pos = randf.inUnitSphere().mul(resolution.x / 4).add(resolution.div(2));
114115
const center = resolution.div(2);
115116
const dir = std.normalize(center.sub(pos));
116-
agentsData.$[x] = Agent({ position: pos, direction: dir });
117+
mutableAgentsDataBuffers[0].$[x] = Agent({ position: pos, direction: dir });
118+
mutableAgentsDataBuffers[1].$[x] = Agent({ position: pos, direction: dir });
117119
}).dispatchThreads(NUM_AGENTS);
118120

119121
const params = root.createUniform(Params, {
@@ -136,8 +138,16 @@ const textures = [0, 1].map(() =>
136138
);
137139

138140
const computeLayout = tgpu.bindGroupLayout({
141+
oldAgents: { storage: d.arrayOf(Agent), access: 'readonly' },
139142
oldState: { storageTexture: d.textureStorage3d('r32float', 'read-only') },
143+
newAgents: { storage: d.arrayOf(Agent), access: 'mutable' },
144+
newState: { storageTexture: d.textureStorage3d('r32float', 'write-only') },
145+
});
146+
147+
const blurLayout = tgpu.bindGroupLayout({
148+
oldState: { texture: d.texture3d() },
140149
newState: { storageTexture: d.textureStorage3d('r32float', 'write-only') },
150+
sampler: { sampler: 'filtering' },
141151
});
142152

143153
const renderLayout = tgpu.bindGroupLayout({
@@ -223,24 +233,18 @@ const updateAgents = tgpu['~unstable'].computeFn({
223233
const dims = std.textureDimensions(computeLayout.$.oldState);
224234
const dimsf = d.vec3f(dims);
225235

226-
const agent = agentsData.$[gid.x];
227-
const random = randf.sample();
236+
const agent = computeLayout.$.oldAgents[gid.x];
228237

229238
let direction = std.normalize(agent.direction);
230239
const senseResult = sense3D(agent.position, direction);
231-
232-
if (senseResult.totalWeight > 0.01) {
233-
const targetDir = std.normalize(senseResult.weightedDir);
234-
direction = std.normalize(
235-
direction.add(targetDir.mul(params.$.turnSpeed * params.$.deltaTime)),
236-
);
237-
} else {
238-
const perp = getPerpendicular(direction);
239-
const randomOffset = perp.mul(
240-
(random * 2 - 1) * params.$.turnSpeed * params.$.deltaTime,
241-
);
242-
direction = std.normalize(direction.add(randomOffset));
243-
}
240+
const targetDirection = std.select(
241+
randf.onHemisphere(direction),
242+
std.normalize(senseResult.weightedDir),
243+
senseResult.totalWeight > 0.01,
244+
);
245+
direction = std.normalize(direction.add(
246+
targetDirection.mul(params.$.turnSpeed * params.$.deltaTime),
247+
));
244248

245249
const newPos = agent.position.add(
246250
direction.mul(params.$.moveSpeed * params.$.deltaTime),
@@ -292,7 +296,7 @@ const updateAgents = tgpu['~unstable'].computeFn({
292296
);
293297
}
294298

295-
agentsData.$[gid.x] = Agent({
299+
computeLayout.$.newAgents[gid.x] = Agent({
296300
position: newPos,
297301
direction,
298302
});
@@ -306,52 +310,47 @@ const updateAgents = tgpu['~unstable'].computeFn({
306310
);
307311
});
308312

313+
const sampler = root['~unstable'].createSampler({
314+
magFilter: canFilter ? 'linear' : 'nearest',
315+
minFilter: canFilter ? 'linear' : 'nearest',
316+
});
317+
318+
const getSummand = tgpu.fn([d.vec3f, d.vec3f], d.f32)((uv, offset) =>
319+
std.textureSampleLevel(
320+
blurLayout.$.oldState,
321+
blurLayout.$.sampler,
322+
uv.add(offset),
323+
0,
324+
).x
325+
);
326+
309327
const blur = tgpu['~unstable'].computeFn({
310328
in: { gid: d.builtin.globalInvocationId },
311329
workgroupSize: BLUR_WORKGROUP_SIZE,
312330
})(({ gid }) => {
313-
const dims = std.textureDimensions(computeLayout.$.oldState);
331+
const dims = d.vec3u(std.textureDimensions(blurLayout.$.oldState));
314332
if (gid.x >= dims.x || gid.y >= dims.y || gid.z >= dims.z) return;
315333

334+
const uv = d.vec3f(gid).add(0.5).div(d.vec3f(dims));
335+
316336
let sum = d.f32();
317-
let count = d.f32();
318-
319-
for (let offsetZ = -1; offsetZ <= 1; offsetZ++) {
320-
for (let offsetY = -1; offsetY <= 1; offsetY++) {
321-
for (let offsetX = -1; offsetX <= 1; offsetX++) {
322-
const samplePos = d.vec3i(gid.xyz).add(
323-
d.vec3i(offsetX, offsetY, offsetZ),
324-
);
325-
const dimsi = d.vec3i(dims);
326-
327-
if (
328-
samplePos.x >= 0 && samplePos.x < dimsi.x &&
329-
samplePos.y >= 0 && samplePos.y < dimsi.y &&
330-
samplePos.z >= 0 && samplePos.z < dimsi.z
331-
) {
332-
const value =
333-
std.textureLoad(computeLayout.$.oldState, d.vec3u(samplePos)).x;
334-
sum = sum + value;
335-
count = count + 1;
336-
}
337-
}
338-
}
339-
}
340337

341-
const blurred = sum / count;
338+
sum += getSummand(uv, d.vec3f(-1, 0, 0).div(d.vec3f(dims)));
339+
sum += getSummand(uv, d.vec3f(1, 0, 0).div(d.vec3f(dims)));
340+
sum += getSummand(uv, d.vec3f(0, -1, 0).div(d.vec3f(dims)));
341+
sum += getSummand(uv, d.vec3f(0, 1, 0).div(d.vec3f(dims)));
342+
sum += getSummand(uv, d.vec3f(0, 0, -1).div(d.vec3f(dims)));
343+
sum += getSummand(uv, d.vec3f(0, 0, 1).div(d.vec3f(dims)));
344+
345+
const blurred = sum / 6.0;
342346
const newValue = std.saturate(blurred - params.$.evaporationRate);
343347
std.textureStore(
344-
computeLayout.$.newState,
348+
blurLayout.$.newState,
345349
gid.xyz,
346350
d.vec4f(newValue, 0, 0, 1),
347351
);
348352
});
349353

350-
const sampler = root['~unstable'].createSampler({
351-
magFilter: canFilter ? 'linear' : 'nearest',
352-
minFilter: canFilter ? 'linear' : 'nearest',
353-
});
354-
355354
// Ray-box intersection
356355
const rayBoxIntersection = (
357356
rayOrigin: d.v3f,
@@ -375,6 +374,7 @@ const fragmentShader = tgpu['~unstable'].fragmentFn({
375374
in: { uv: d.vec2f },
376375
out: d.vec4f,
377376
})(({ uv }) => {
377+
randf.seed2(uv);
378378
const ndc = d.vec2f(uv.x * 2 - 1, 1 - uv.y * 2);
379379
const ndcNear = d.vec4f(ndc, -1, 1);
380380
const ndcFar = d.vec4f(ndc, 1, 1);
@@ -393,11 +393,23 @@ const fragmentShader = tgpu['~unstable'].fragmentFn({
393393
return d.vec4f();
394394
}
395395

396-
// March params
397-
const tStart = std.max(isect.tNear, 0);
396+
const jitter = randf.sample() * 20;
397+
const tStart = std.max(isect.tNear + jitter, jitter);
398398
const tEnd = isect.tFar;
399-
const numSteps = RAYMARCH_STEPS;
400-
const stepSize = (tEnd - tStart) / numSteps;
399+
400+
const intersectionLength = tEnd - tStart;
401+
const baseStepsPerUnit = d.f32(0.3);
402+
const minSteps = d.i32(8);
403+
const maxSteps = d.i32(RAYMARCH_STEPS);
404+
405+
const adaptiveSteps = std.clamp(
406+
d.i32(intersectionLength * baseStepsPerUnit),
407+
minSteps,
408+
maxSteps,
409+
);
410+
411+
const numSteps = adaptiveSteps;
412+
const stepSize = intersectionLength / d.f32(numSteps);
401413

402414
const thresholdLo = d.f32(0.06);
403415
const thresholdHi = d.f32(0.25);
@@ -411,11 +423,8 @@ const fragmentShader = tgpu['~unstable'].fragmentFn({
411423

412424
const TMin = d.f32(1e-3);
413425

414-
for (let i = 0; i < numSteps; i++) {
415-
if (transmittance <= TMin) {
416-
break;
417-
}
418-
426+
let i = d.i32(0);
427+
while (i < numSteps && transmittance > TMin) {
419428
const t = tStart + (d.f32(i) + 0.5) * stepSize;
420429
const pos = rayOrigin.add(rayDir.mul(t));
421430
const texCoord = pos.div(resolution);
@@ -433,6 +442,8 @@ const fragmentShader = tgpu['~unstable'].fragmentFn({
433442

434443
accum = accum.add(contrib.mul(transmittance));
435444
transmittance = transmittance * (1 - alphaSrc);
445+
446+
i += 1;
436447
}
437448

438449
const alpha = 1 - transmittance;
@@ -454,8 +465,18 @@ const blurPipeline = root['~unstable']
454465

455466
const bindGroups = [0, 1].map((i) =>
456467
root.createBindGroup(computeLayout, {
468+
oldAgents: agentsDataBuffers[i],
469+
oldState: textures[i],
470+
newAgents: agentsDataBuffers[1 - i],
471+
newState: textures[1 - i],
472+
})
473+
);
474+
475+
const blurBindGroups = [0, 1].map((i) =>
476+
root.createBindGroup(blurLayout, {
457477
oldState: textures[i],
458478
newState: textures[1 - i],
479+
sampler: sampler,
459480
})
460481
);
461482

@@ -476,7 +497,7 @@ function frame() {
476497
params.writePartial({ deltaTime });
477498

478499
blurPipeline
479-
.with(bindGroups[currentTexture])
500+
.with(blurBindGroups[currentTexture])
480501
.dispatchWorkgroups(
481502
Math.ceil(resolution.x / BLUR_WORKGROUP_SIZE[0]),
482503
Math.ceil(resolution.y / BLUR_WORKGROUP_SIZE[1]),

0 commit comments

Comments
 (0)