Skip to content

Commit a349f37

Browse files
authored
docs: Background segmentation - custom models (#1951)
1 parent 40573f1 commit a349f37

4 files changed

Lines changed: 270 additions & 45 deletions

File tree

apps/typegpu-docs/src/examples/image-processing/background-segmentation/index.ts

Lines changed: 88 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,16 @@ import tgpu, {
77
} from 'typegpu';
88
import { fullScreenTriangle } from 'typegpu/common';
99
import * as d from 'typegpu/data';
10-
import { MODEL_HEIGHT, MODEL_WIDTH, prepareSession } from './model.ts';
10+
import { MODEL_HEIGHT, MODEL_WIDTH, MODELS, prepareSession } from './model.ts';
1111
import {
1212
blockDim,
1313
blurLayout,
1414
drawWithMaskLayout,
15+
flipSlot,
1516
generateMaskLayout,
17+
Params,
18+
paramsAccessor,
1619
prepareModelInputLayout,
17-
sampleBiasSlot,
18-
useGaussianSlot,
1920
} from './schemas.ts';
2021
import {
2122
computeFn,
@@ -62,7 +63,7 @@ const oldRequestAdapter = navigator.gpu.requestAdapter;
6263
const oldRequestDevice = adapter.requestDevice;
6364
navigator.gpu.requestAdapter = async () => adapter;
6465
adapter.requestDevice = async () => device;
65-
const root = await tgpu.initFromDevice({ device });
66+
const root = tgpu.initFromDevice({ device });
6667
const context = canvas.getContext('webgpu') as GPUCanvasContext;
6768
const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
6869

@@ -76,11 +77,13 @@ context.configure({
7677

7778
let blurStrength = 5;
7879
let useGaussianBlur = false;
80+
let useSquareCrop = false;
7981

80-
const zeroBuffer = root.createBuffer(d.u32, 0).$usage('uniform');
81-
const oneBuffer = root.createBuffer(d.u32, 1).$usage('uniform');
82-
const useGaussianUniform = root.createUniform(d.u32, 0);
83-
const sampleBiasUniform = root.createUniform(d.f32, 0);
82+
const paramsUniform = root.createUniform(Params, {
83+
cropBounds: d.vec4f(0, 0, 1, 1),
84+
useGaussian: 0,
85+
sampleBias: blurStrength,
86+
});
8487

8588
const sampler = root['~unstable'].createSampler({
8689
magFilter: 'linear',
@@ -123,27 +126,48 @@ let blurBindGroups: TgpuBindGroup<typeof blurLayout.entries>[];
123126
// pipelines
124127

125128
const prepareModelInputPipeline = root['~unstable']
129+
.with(paramsAccessor, paramsUniform)
126130
.createGuardedComputePipeline(
127131
prepareModelInput,
128132
);
129133

130-
const session = await prepareSession(
134+
let currentModelIndex = 0;
135+
let session = await prepareSession(
131136
root.unwrap(modelInputBuffer),
132137
root.unwrap(modelOutputBuffer),
138+
MODELS[currentModelIndex],
133139
);
140+
let isLoadingModel = false;
141+
142+
async function switchModel(modelIndex: number) {
143+
if (isLoadingModel || modelIndex === currentModelIndex) return;
144+
isLoadingModel = true;
145+
146+
const oldSession = session;
147+
currentModelIndex = modelIndex;
148+
session = await prepareSession(
149+
root.unwrap(modelInputBuffer),
150+
root.unwrap(modelOutputBuffer),
151+
MODELS[currentModelIndex],
152+
);
153+
oldSession.release();
154+
isLoadingModel = false;
155+
}
134156

135157
const generateMaskFromOutputPipeline = root['~unstable']
136158
.createGuardedComputePipeline(
137159
generateMaskFromOutput,
138160
);
139161

140-
const blurPipeline = root['~unstable']
141-
.withCompute(computeFn)
142-
.createPipeline();
162+
const blurPipelines = [false, true].map((flip) =>
163+
root['~unstable']
164+
.with(flipSlot, flip)
165+
.withCompute(computeFn)
166+
.createPipeline()
167+
);
143168

144169
const drawWithMaskPipeline = root['~unstable']
145-
.with(useGaussianSlot, useGaussianUniform)
146-
.with(sampleBiasSlot, sampleBiasUniform)
170+
.with(paramsAccessor, paramsUniform)
147171
.withVertex(fullScreenTriangle, {})
148172
.withFragment(drawWithMaskFragment, { format: presentationFormat })
149173
.createPipeline();
@@ -153,7 +177,7 @@ const drawWithMaskPipeline = root['~unstable']
153177
let calculateMaskCallbackId: number | undefined;
154178

155179
async function processCalculateMask() {
156-
if (video.readyState < 2) {
180+
if (video.readyState < 2 || isLoadingModel) {
157181
calculateMaskCallbackId = video.requestVideoFrameCallback(
158182
processCalculateMask,
159183
);
@@ -181,6 +205,29 @@ async function processCalculateMask() {
181205
calculateMaskCallbackId = video.requestVideoFrameCallback(processCalculateMask);
182206

183207
// frame
208+
function updateCropBounds(aspectRatio: number) {
209+
let uvMinX = 0;
210+
let uvMinY = 0;
211+
let uvMaxX = 1;
212+
let uvMaxY = 1;
213+
214+
if (useSquareCrop) {
215+
if (aspectRatio > 1) {
216+
// wide (e.g. 16:9) - crop horizontally
217+
const cropWidth = 1 / aspectRatio; // width of square in UV space
218+
uvMinX = (1 - cropWidth) / 2;
219+
uvMaxX = uvMinX + cropWidth;
220+
} else if (aspectRatio < 1) {
221+
// tall - crop vertically
222+
const cropHeight = aspectRatio; // height of square in UV space
223+
uvMinY = (1 - cropHeight) / 2;
224+
uvMaxY = uvMinY + cropHeight;
225+
}
226+
}
227+
paramsUniform.writePartial({
228+
cropBounds: d.vec4f(uvMinX, uvMinY, uvMaxX, uvMaxY),
229+
});
230+
}
184231

185232
function onVideoChange(size: { width: number; height: number }) {
186233
const aspectRatio = size.width / size.height;
@@ -190,17 +237,18 @@ function onVideoChange(size: { width: number; height: number }) {
190237
canvas.parentElement.style.height =
191238
`min(100cqh, calc(100cqw/(${aspectRatio})))`;
192239
}
240+
241+
updateCropBounds(aspectRatio);
242+
193243
blurredTextures = [0, 1].map(() =>
194244
root['~unstable'].createTexture({
195245
size: [size.width, size.height],
196246
format: 'rgba8unorm',
197-
dimension: '2d',
198247
mipLevelCount: 10,
199248
}).$usage('sampled', 'render', 'storage')
200249
);
201250
blurBindGroups = [
202251
root.createBindGroup(blurLayout, {
203-
flip: zeroBuffer,
204252
inTexture: blurredTextures[0],
205253
outTexture: blurredTextures[1].createView(
206254
d.textureStorage2d('rgba8unorm', 'read-only'),
@@ -209,7 +257,6 @@ function onVideoChange(size: { width: number; height: number }) {
209257
sampler,
210258
}),
211259
root.createBindGroup(blurLayout, {
212-
flip: oneBuffer,
213260
inTexture: blurredTextures[1],
214261
outTexture: blurredTextures[0].createView(
215262
d.textureStorage2d('rgba8unorm', 'read-only'),
@@ -248,13 +295,13 @@ async function processVideoFrame(
248295

249296
if (useGaussianBlur) {
250297
for (const _ of Array(blurStrength * 2)) {
251-
blurPipeline
298+
blurPipelines[0]
252299
.with(blurBindGroups[0])
253300
.dispatchWorkgroups(
254301
Math.ceil(frameWidth / blockDim),
255302
Math.ceil(frameHeight / 4),
256303
);
257-
blurPipeline
304+
blurPipelines[1]
258305
.with(blurBindGroups[1])
259306
.dispatchWorkgroups(
260307
Math.ceil(frameHeight / blockDim),
@@ -287,12 +334,22 @@ videoFrameCallbackId = video.requestVideoFrameCallback(processVideoFrame);
287334
// #region Example controls & Cleanup
288335

289336
export const controls = {
337+
model: {
338+
initial: MODELS[0].name,
339+
options: MODELS.map((m) => m.name),
340+
async onSelectChange(value: string) {
341+
const index = MODELS.findIndex((m) => m.name === value);
342+
if (index !== -1) {
343+
await switchModel(index);
344+
}
345+
},
346+
},
290347
'blur type': {
291348
initial: 'mipmaps',
292349
options: ['mipmaps', 'gaussian'],
293350
async onSelectChange(value: string) {
294351
useGaussianBlur = value === 'gaussian';
295-
useGaussianUniform.write(useGaussianBlur ? 1 : 0);
352+
paramsUniform.writePartial({ useGaussian: useGaussianBlur ? 1 : 0 });
296353
},
297354
},
298355
'blur strength': {
@@ -302,7 +359,16 @@ export const controls = {
302359
step: 1,
303360
onSliderChange(newValue: number) {
304361
blurStrength = newValue;
305-
sampleBiasUniform.write(blurStrength);
362+
paramsUniform.writePartial({ sampleBias: blurStrength });
363+
},
364+
},
365+
'square crop': {
366+
initial: useSquareCrop,
367+
onToggleChange(value: boolean) {
368+
useSquareCrop = value;
369+
if (lastFrameSize) {
370+
updateCropBounds(lastFrameSize.width / lastFrameSize.height);
371+
}
306372
},
307373
},
308374
};

0 commit comments

Comments
 (0)