@@ -7,15 +7,16 @@ import tgpu, {
77} from 'typegpu' ;
88import { fullScreenTriangle } from 'typegpu/common' ;
99import * as d from 'typegpu/data' ;
10- import { MODEL_HEIGHT , MODEL_WIDTH , prepareSession } from './model.ts' ;
10+ import { MODEL_HEIGHT , MODEL_WIDTH , MODELS , prepareSession } from './model.ts' ;
1111import {
1212 blockDim ,
1313 blurLayout ,
1414 drawWithMaskLayout ,
15+ flipSlot ,
1516 generateMaskLayout ,
17+ Params ,
18+ paramsAccessor ,
1619 prepareModelInputLayout ,
17- sampleBiasSlot ,
18- useGaussianSlot ,
1920} from './schemas.ts' ;
2021import {
2122 computeFn ,
@@ -62,7 +63,7 @@ const oldRequestAdapter = navigator.gpu.requestAdapter;
6263const oldRequestDevice = adapter . requestDevice ;
6364navigator . gpu . requestAdapter = async ( ) => adapter ;
6465adapter . requestDevice = async ( ) => device ;
65- const root = await tgpu . initFromDevice ( { device } ) ;
66+ const root = tgpu . initFromDevice ( { device } ) ;
6667const context = canvas . getContext ( 'webgpu' ) as GPUCanvasContext ;
6768const presentationFormat = navigator . gpu . getPreferredCanvasFormat ( ) ;
6869
@@ -76,11 +77,13 @@ context.configure({
7677
7778let blurStrength = 5 ;
7879let useGaussianBlur = false ;
80+ let useSquareCrop = false ;
7981
80- const zeroBuffer = root . createBuffer ( d . u32 , 0 ) . $usage ( 'uniform' ) ;
81- const oneBuffer = root . createBuffer ( d . u32 , 1 ) . $usage ( 'uniform' ) ;
82- const useGaussianUniform = root . createUniform ( d . u32 , 0 ) ;
83- const sampleBiasUniform = root . createUniform ( d . f32 , 0 ) ;
82+ const paramsUniform = root . createUniform ( Params , {
83+ cropBounds : d . vec4f ( 0 , 0 , 1 , 1 ) ,
84+ useGaussian : 0 ,
85+ sampleBias : blurStrength ,
86+ } ) ;
8487
8588const sampler = root [ '~unstable' ] . createSampler ( {
8689 magFilter : 'linear' ,
@@ -123,27 +126,48 @@ let blurBindGroups: TgpuBindGroup<typeof blurLayout.entries>[];
123126// pipelines
124127
125128const prepareModelInputPipeline = root [ '~unstable' ]
129+ . with ( paramsAccessor , paramsUniform )
126130 . createGuardedComputePipeline (
127131 prepareModelInput ,
128132 ) ;
129133
130- const session = await prepareSession (
134+ let currentModelIndex = 0 ;
135+ let session = await prepareSession (
131136 root . unwrap ( modelInputBuffer ) ,
132137 root . unwrap ( modelOutputBuffer ) ,
138+ MODELS [ currentModelIndex ] ,
133139) ;
140+ let isLoadingModel = false ;
141+
142+ async function switchModel ( modelIndex : number ) {
143+ if ( isLoadingModel || modelIndex === currentModelIndex ) return ;
144+ isLoadingModel = true ;
145+
146+ const oldSession = session ;
147+ currentModelIndex = modelIndex ;
148+ session = await prepareSession (
149+ root . unwrap ( modelInputBuffer ) ,
150+ root . unwrap ( modelOutputBuffer ) ,
151+ MODELS [ currentModelIndex ] ,
152+ ) ;
153+ oldSession . release ( ) ;
154+ isLoadingModel = false ;
155+ }
134156
135157const generateMaskFromOutputPipeline = root [ '~unstable' ]
136158 . createGuardedComputePipeline (
137159 generateMaskFromOutput ,
138160 ) ;
139161
140- const blurPipeline = root [ '~unstable' ]
141- . withCompute ( computeFn )
142- . createPipeline ( ) ;
162+ const blurPipelines = [ false , true ] . map ( ( flip ) =>
163+ root [ '~unstable' ]
164+ . with ( flipSlot , flip )
165+ . withCompute ( computeFn )
166+ . createPipeline ( )
167+ ) ;
143168
144169const drawWithMaskPipeline = root [ '~unstable' ]
145- . with ( useGaussianSlot , useGaussianUniform )
146- . with ( sampleBiasSlot , sampleBiasUniform )
170+ . with ( paramsAccessor , paramsUniform )
147171 . withVertex ( fullScreenTriangle , { } )
148172 . withFragment ( drawWithMaskFragment , { format : presentationFormat } )
149173 . createPipeline ( ) ;
@@ -153,7 +177,7 @@ const drawWithMaskPipeline = root['~unstable']
153177let calculateMaskCallbackId : number | undefined ;
154178
155179async function processCalculateMask ( ) {
156- if ( video . readyState < 2 ) {
180+ if ( video . readyState < 2 || isLoadingModel ) {
157181 calculateMaskCallbackId = video . requestVideoFrameCallback (
158182 processCalculateMask ,
159183 ) ;
@@ -181,6 +205,29 @@ async function processCalculateMask() {
181205calculateMaskCallbackId = video . requestVideoFrameCallback ( processCalculateMask ) ;
182206
183207// frame
208+ function updateCropBounds ( aspectRatio : number ) {
209+ let uvMinX = 0 ;
210+ let uvMinY = 0 ;
211+ let uvMaxX = 1 ;
212+ let uvMaxY = 1 ;
213+
214+ if ( useSquareCrop ) {
215+ if ( aspectRatio > 1 ) {
216+ // wide (e.g. 16:9) - crop horizontally
217+ const cropWidth = 1 / aspectRatio ; // width of square in UV space
218+ uvMinX = ( 1 - cropWidth ) / 2 ;
219+ uvMaxX = uvMinX + cropWidth ;
220+ } else if ( aspectRatio < 1 ) {
221+ // tall - crop vertically
222+ const cropHeight = aspectRatio ; // height of square in UV space
223+ uvMinY = ( 1 - cropHeight ) / 2 ;
224+ uvMaxY = uvMinY + cropHeight ;
225+ }
226+ }
227+ paramsUniform . writePartial ( {
228+ cropBounds : d . vec4f ( uvMinX , uvMinY , uvMaxX , uvMaxY ) ,
229+ } ) ;
230+ }
184231
185232function onVideoChange ( size : { width : number ; height : number } ) {
186233 const aspectRatio = size . width / size . height ;
@@ -190,17 +237,18 @@ function onVideoChange(size: { width: number; height: number }) {
190237 canvas . parentElement . style . height =
191238 `min(100cqh, calc(100cqw/(${ aspectRatio } )))` ;
192239 }
240+
241+ updateCropBounds ( aspectRatio ) ;
242+
193243 blurredTextures = [ 0 , 1 ] . map ( ( ) =>
194244 root [ '~unstable' ] . createTexture ( {
195245 size : [ size . width , size . height ] ,
196246 format : 'rgba8unorm' ,
197- dimension : '2d' ,
198247 mipLevelCount : 10 ,
199248 } ) . $usage ( 'sampled' , 'render' , 'storage' )
200249 ) ;
201250 blurBindGroups = [
202251 root . createBindGroup ( blurLayout , {
203- flip : zeroBuffer ,
204252 inTexture : blurredTextures [ 0 ] ,
205253 outTexture : blurredTextures [ 1 ] . createView (
206254 d . textureStorage2d ( 'rgba8unorm' , 'read-only' ) ,
@@ -209,7 +257,6 @@ function onVideoChange(size: { width: number; height: number }) {
209257 sampler,
210258 } ) ,
211259 root . createBindGroup ( blurLayout , {
212- flip : oneBuffer ,
213260 inTexture : blurredTextures [ 1 ] ,
214261 outTexture : blurredTextures [ 0 ] . createView (
215262 d . textureStorage2d ( 'rgba8unorm' , 'read-only' ) ,
@@ -248,13 +295,13 @@ async function processVideoFrame(
248295
249296 if ( useGaussianBlur ) {
250297 for ( const _ of Array ( blurStrength * 2 ) ) {
251- blurPipeline
298+ blurPipelines [ 0 ]
252299 . with ( blurBindGroups [ 0 ] )
253300 . dispatchWorkgroups (
254301 Math . ceil ( frameWidth / blockDim ) ,
255302 Math . ceil ( frameHeight / 4 ) ,
256303 ) ;
257- blurPipeline
304+ blurPipelines [ 1 ]
258305 . with ( blurBindGroups [ 1 ] )
259306 . dispatchWorkgroups (
260307 Math . ceil ( frameHeight / blockDim ) ,
@@ -287,12 +334,22 @@ videoFrameCallbackId = video.requestVideoFrameCallback(processVideoFrame);
287334// #region Example controls & Cleanup
288335
289336export const controls = {
337+ model : {
338+ initial : MODELS [ 0 ] . name ,
339+ options : MODELS . map ( ( m ) => m . name ) ,
340+ async onSelectChange ( value : string ) {
341+ const index = MODELS . findIndex ( ( m ) => m . name === value ) ;
342+ if ( index !== - 1 ) {
343+ await switchModel ( index ) ;
344+ }
345+ } ,
346+ } ,
290347 'blur type' : {
291348 initial : 'mipmaps' ,
292349 options : [ 'mipmaps' , 'gaussian' ] ,
293350 async onSelectChange ( value : string ) {
294351 useGaussianBlur = value === 'gaussian' ;
295- useGaussianUniform . write ( useGaussianBlur ? 1 : 0 ) ;
352+ paramsUniform . writePartial ( { useGaussian : useGaussianBlur ? 1 : 0 } ) ;
296353 } ,
297354 } ,
298355 'blur strength' : {
@@ -302,7 +359,16 @@ export const controls = {
302359 step : 1 ,
303360 onSliderChange ( newValue : number ) {
304361 blurStrength = newValue ;
305- sampleBiasUniform . write ( blurStrength ) ;
362+ paramsUniform . writePartial ( { sampleBias : blurStrength } ) ;
363+ } ,
364+ } ,
365+ 'square crop' : {
366+ initial : useSquareCrop ,
367+ onToggleChange ( value : boolean ) {
368+ useSquareCrop = value ;
369+ if ( lastFrameSize ) {
370+ updateCropBounds ( lastFrameSize . width / lastFrameSize . height ) ;
371+ }
306372 } ,
307373 } ,
308374} ;
0 commit comments