Skip to content

Commit 55bbb41

Browse files
authored
fix: Execute tgpu.comptime in "normal" mode (#1957)
1 parent 9334be9 commit 55bbb41

2 files changed

Lines changed: 26 additions & 4 deletions

File tree

packages/typegpu/src/core/function/comptime.ts

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import type { DualFn } from '../../data/dualFn.ts';
22
import type { MapValueToSnippet } from '../../data/snippet.ts';
33
import { WgslTypeError } from '../../errors.ts';
4-
import { inCodegenMode } from '../../execMode.ts';
4+
import { getResolutionCtx } from '../../execMode.ts';
55
import { setName, type TgpuNamable } from '../../shared/meta.ts';
66
import { $getNameForward, $internal } from '../../shared/symbols.ts';
77
import { coerceToSnippet } from '../../tgsl/generationHelpers.ts';
8-
import { isKnownAtComptime } from '../../types.ts';
8+
import { isKnownAtComptime, NormalState } from '../../types.ts';
99

1010
export type TgpuComptime<T extends (...args: never[]) => unknown> =
1111
& DualFn<T>
@@ -52,8 +52,14 @@ export function comptime<T extends (...args: never[]) => unknown>(
5252
};
5353

5454
const impl = ((...args: Parameters<T>) => {
55-
if (inCodegenMode()) {
56-
return gpuImpl(...args as MapValueToSnippet<Parameters<T>>);
55+
const ctx = getResolutionCtx();
56+
if (ctx?.mode.type === 'codegen') {
57+
ctx.pushMode(new NormalState());
58+
try {
59+
return gpuImpl(...args as MapValueToSnippet<Parameters<T>>);
60+
} finally {
61+
ctx.popMode('normal');
62+
}
5763
}
5864
return func(...args);
5965
}) as TgpuComptime<T>;

packages/typegpu/tests/tgsl/comptime.test.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,20 @@ describe('comptime', () => {
4848
}"
4949
`);
5050
});
51+
52+
it('should work in "normal" mode', () => {
53+
const stagger = tgpu['~unstable'].comptime((v: d.v3f) => {
54+
return v.add(d.vec3f(0, 1, 2));
55+
});
56+
57+
const myFn = tgpu.fn([], d.f32)(() => {
58+
return stagger(d.vec3f(2)).z;
59+
});
60+
61+
expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(`
62+
"fn myFn() -> f32 {
63+
return 4f;
64+
}"
65+
`);
66+
});
5167
});

0 commit comments

Comments
 (0)