diff --git a/packages/typegpu/src/core/function/createCallableSchema.ts b/packages/typegpu/src/core/function/createCallableSchema.ts index 970c950fa8..369bca3f02 100644 --- a/packages/typegpu/src/core/function/createCallableSchema.ts +++ b/packages/typegpu/src/core/function/createCallableSchema.ts @@ -1,4 +1,4 @@ -import { type MapValueToSnippet, type ResolvedSnippet, snip } from '../../data/snippet.ts'; +import { type MapValueToSnippet, snip, type Snippet } from '../../data/snippet.ts'; import { type BaseData, isPtr } from '../../data/wgslTypes.ts'; import { setName } from '../../shared/meta.ts'; import { $gpuCallable } from '../../shared/symbols.ts'; @@ -12,10 +12,7 @@ interface CallableSchemaOptions { readonly name: string; readonly schema: () => BaseData; readonly normalImpl: T; - readonly codegenImpl: ( - ctx: ResolutionCtx, - args: MapValueToSnippet>, - ) => ResolvedSnippet; + readonly codegenImpl: (ctx: ResolutionCtx, args: MapValueToSnippet>) => Snippet; readonly argTypes: ( ...inArgTypes: MapValueToDataType> ) => (BaseData | BaseData[])[]; diff --git a/packages/typegpu/src/data/numeric.ts b/packages/typegpu/src/data/numeric.ts index d49368843b..b898346306 100644 --- a/packages/typegpu/src/data/numeric.ts +++ b/packages/typegpu/src/data/numeric.ts @@ -1,6 +1,8 @@ import { $internal } from '../shared/symbols.ts'; +import { isBool, isNumericSchema } from './wgslTypes.ts'; import type { AbstractFloat, AbstractInt, Bool, F16, F32, I32, U16, U32 } from './wgslTypes.ts'; import { callableSchema } from '../core/function/createCallableSchema.ts'; +import { snip } from './snippet.ts'; export const abstractInt = { [$internal]: {}, @@ -22,16 +24,34 @@ const boolCast = callableSchema({ name: 'bool', schema: () => bool, argTypes: (arg) => (arg ? [arg] : []), - normalImpl(v?: number | boolean) { - if (v === undefined) { - return false; + normalImpl(v?: unknown) { + return !!v; + }, + codegenImpl: (ctx, [arg]) => { + if (arg === undefined) { + return snip(false, bool, 'constant'); } - if (typeof v === 'boolean') { - return v; + + const { dataType } = arg; + + if (isBool(dataType)) { + return ctx.gen.typeInstantiation(bool, [arg]); } - return !!v; + + if (isNumericSchema(dataType)) { + const argStr = ctx.resolveSnippet(arg).value; + const resultStr = `bool(${argStr})`; + + const nanGuardedStr = + dataType.type === 'f32' || dataType.type === 'f16' + ? `(((bitcast(${dataType.type === 'f16' ? `f32(${argStr})` : argStr}) << 1u) - 1u) < 0xff000000)` + : resultStr; + + return snip(nanGuardedStr, bool, 'runtime'); + } + + return snip(true, bool, 'constant'); }, - codegenImpl: (ctx, args) => ctx.gen.typeInstantiation(bool, args), }); /** diff --git a/packages/typegpu/src/data/wgslTypes.ts b/packages/typegpu/src/data/wgslTypes.ts index dc85031925..e7f4a3e6de 100644 --- a/packages/typegpu/src/data/wgslTypes.ts +++ b/packages/typegpu/src/data/wgslTypes.ts @@ -602,7 +602,7 @@ export type mBaseForVec = T extends v2f * Boolean schema representing a single WGSL bool value. * Cannot be used inside buffers as it is not host-shareable. */ -export interface Bool extends BaseData, DualFn<(v?: number | boolean) => boolean> { +export interface Bool extends BaseData, DualFn<(v?: unknown) => boolean> { readonly type: 'bool'; // Type-tokens, not available at runtime diff --git a/packages/typegpu/src/tgsl/wgslGenerator.ts b/packages/typegpu/src/tgsl/wgslGenerator.ts index 7c7c06d7c8..c8cf4e251b 100644 --- a/packages/typegpu/src/tgsl/wgslGenerator.ts +++ b/packages/typegpu/src/tgsl/wgslGenerator.ts @@ -162,29 +162,13 @@ const unaryOpCodeToCodegen = { throw new Error('The unary operator `!` expects 1 argument, but 0 were provided.'); } - if (isKnownAtComptime(argExpr)) { - return snip(!argExpr.value, bool, 'constant'); - } - - const { value, dataType } = argExpr; - const argStr = ctx.resolve(value, dataType).value; + const convArg = bool[$gpuCallable].call(ctx, [argExpr]); - if (wgsl.isBool(dataType)) { - return snip(`!${argStr}`, bool, 'runtime'); - } - if (wgsl.isNumericSchema(dataType)) { - const resultStr = `!bool(${argStr})`; - const nanGuardedStr = // abstractFloat will be resolved as comptime known value - dataType.type === 'f32' - ? `(((bitcast(${argStr}) & 0x7fffffff) > 0x7f800000) || ${resultStr})` - : dataType.type === 'f16' - ? `(((bitcast(${argStr}) & 0x7fff) > 0x7c00) || ${resultStr})` - : resultStr; - - return snip(nanGuardedStr, bool, 'runtime'); + if (isKnownAtComptime(convArg)) { + return snip(!convArg.value, bool, 'constant'); } - return snip(false, bool, 'constant'); + return snip(`!${convArg.value}`, bool, 'runtime'); }, } satisfies Partial unknown>>; @@ -329,6 +313,11 @@ ${this.ctx.pre}}`; // convert the result. return result; } + + if (wgsl.isBool(expectedType)) { + return bool[$gpuCallable].call(this.ctx, [result]); + } + return tryConvertSnippet(this.ctx, result, expectedType); } finally { this.ctx.expectedType = prevExpectedType; @@ -344,17 +333,12 @@ ${this.ctx.pre}}`; return snip(expression, bool, /* origin */ 'constant'); } - if ( - expression[0] === NODE.logicalExpr || - expression[0] === NODE.binaryExpr || - expression[0] === NODE.assignmentExpr - ) { - // Logical/Binary/Assignment Expression - const [exprType, lhs, op, rhs] = expression; + if (expression[0] === NODE.logicalExpr) { + const [_, lhs, op, rhs] = expression; const lhsExpr = this._expression(lhs); // Short Circuit Evaluation - if ((op === '||' || op === '&&') && isKnownAtComptime(lhsExpr)) { + if (isKnownAtComptime(lhsExpr)) { const evalRhs = op === '&&' ? lhsExpr.value : !lhsExpr.value; if (!evalRhs) { @@ -363,14 +347,14 @@ ${this.ctx.pre}}`; const rhsExpr = this._expression(rhs); - if (rhsExpr.dataType === UnknownData) { - throw new WgslTypeError(`Right-hand side of '${op}' is of unknown type`); - } - if (isKnownAtComptime(rhsExpr)) { return snip(!!rhsExpr.value, bool, 'constant'); } + if (rhsExpr.dataType === UnknownData) { + throw new WgslTypeError(`Right-hand side of '${op}' is of unknown type`); + } + // we can skip lhs const convRhs = tryConvertSnippet(this.ctx, rhsExpr, bool, false); const rhsStr = this.ctx.resolve(convRhs.value, convRhs.dataType).value; @@ -379,6 +363,32 @@ ${this.ctx.pre}}`; const rhsExpr = this._expression(rhs); + // they are not known at comptime + if (lhsExpr.dataType === UnknownData) { + throw new WgslTypeError(`Left-hand side of '${op}' is of unknown type`); + } + + if (!isKnownAtComptime(rhsExpr) && rhsExpr.dataType === UnknownData) { + throw new WgslTypeError(`Right-hand side of '${op}' is of unknown type`); + } + + const [convLhs, convRhs] = convertToCommonType(this.ctx, [lhsExpr, rhsExpr], [bool]) ?? [ + lhsExpr, + rhsExpr, + ]; + + const lhsStr = this.ctx.resolve(convLhs.value, convLhs.dataType).value; + const rhsStr = this.ctx.resolve(convRhs.value, convRhs.dataType).value; + + return snip(`(${lhsStr} ${op} ${rhsStr})`, bool, /* origin */ 'runtime'); + } + + if (expression[0] === NODE.binaryExpr || expression[0] === NODE.assignmentExpr) { + // Binary/Assignment Expression + const [exprType, lhs, op, rhs] = expression; + const lhsExpr = this._expression(lhs); + const rhsExpr = this._expression(rhs); + if (rhsExpr.value instanceof RefOperator) { throw new WgslTypeError( stitch`Cannot assign a ref to an existing variable '${lhsExpr}', define a new variable instead.`, diff --git a/packages/typegpu/tests/tgsl/typeInference.test.ts b/packages/typegpu/tests/tgsl/typeInference.test.ts index 10c6c81726..55d6df4a0c 100644 --- a/packages/typegpu/tests/tgsl/typeInference.test.ts +++ b/packages/typegpu/tests/tgsl/typeInference.test.ts @@ -271,7 +271,7 @@ describe('wgsl generator type inference', () => { `); }); - it('throws when if condition is not boolean', () => { + it('converts if condition to boolean', () => { const myFn = tgpu.fn( [], d.bool, @@ -282,14 +282,17 @@ describe('wgsl generator type inference', () => { return false; }); - expect(() => tgpu.resolve([myFn])).toThrowErrorMatchingInlineSnapshot(` - [Error: Resolution of the following tree failed: - - - - fn:myFn: Cannot convert value of type 'vec2' to any of the target types: [bool]] + expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(` + "fn myFn() -> bool { + { + return true; + } + return false; + }" `); }); - it('throws when while condition is not boolean', () => { + it('converts while condition to boolean', () => { const myFn = tgpu.fn( [], d.bool, @@ -300,14 +303,17 @@ describe('wgsl generator type inference', () => { return false; }); - expect(() => tgpu.resolve([myFn])).toThrowErrorMatchingInlineSnapshot(` - [Error: Resolution of the following tree failed: - - - - fn:myFn: Cannot convert value of type 'mat2x2f' to any of the target types: [bool]] + expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(` + "fn myFn() -> bool { + while (true) { + return true; + } + return false; + }" `); }); - it('throws when for condition is not boolean', () => { + it('converts for condition to boolean', () => { const myFn = tgpu.fn( [], d.bool, @@ -318,10 +324,13 @@ describe('wgsl generator type inference', () => { return false; }); - expect(() => tgpu.resolve([myFn])).toThrowErrorMatchingInlineSnapshot(` - [Error: Resolution of the following tree failed: - - - - fn:myFn: Cannot convert value of type 'abstractInt' to any of the target types: [bool]] + expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(` + "fn myFn() -> bool { + for (var i = 0; true; (i < 10i)) { + return true; + } + return false; + }" `); }); diff --git a/packages/typegpu/tests/tgsl/wgslGenerator.test.ts b/packages/typegpu/tests/tgsl/wgslGenerator.test.ts index 62c96fe428..25a1c0f7c9 100644 --- a/packages/typegpu/tests/tgsl/wgslGenerator.test.ts +++ b/packages/typegpu/tests/tgsl/wgslGenerator.test.ts @@ -1,5 +1,5 @@ import * as tinyest from 'tinyest'; -import { beforeEach, describe, expect, vi } from 'vitest'; +import { beforeEach, describe, expect, test, vi } from 'vitest'; import { namespace } from '../../src/core/resolve/namespace.ts'; import * as d from '../../src/data/index.ts'; import { abstractFloat, abstractInt } from '../../src/data/numeric.ts'; @@ -2137,9 +2137,7 @@ describe('wgslGenerator', () => { }); expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(` - "@group(0) @binding(0) var buffer: mat4x4f; - - fn testFn(v: vec3f, a: atomic, p: ptr) { + "fn testFn(v: vec3f, a: atomic, p: ptr) { const _b0 = false; const _b1 = false; const _b2 = false; @@ -2234,6 +2232,262 @@ describe('wgslGenerator', () => { `); }); + describe('handles truthiness check', () => { + it('boolean runtime-known operand', () => { + const testFn = tgpu.fn( + [d.bool], + d.i32, + )((b) => { + let res = -1; + if (b) { + res = 1; + } + return res; + }); + + expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(` + "fn testFn(b: bool) -> i32 { + var res = -1; + if (b) { + res = 1i; + } + return res; + }" + `); + }); + + it('numeric runtime-known operand', () => { + const testFn = tgpu.fn( + [d.i32, d.f32, d.f16], + d.i32, + )((n, f, h) => { + let res = -1; + if (n) { + res = 1; + } + if (f) { + res = 2; + } + if (h) { + res = 3; + } + return res; + }); + + expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(` + "fn testFn(n: i32, f: f32, h: f16) -> i32 { + var res = -1; + if (bool(n)) { + res = 1i; + } + if ((((bitcast(f) << 1u) - 1u) < 0xff000000)) { + res = 2i; + } + if ((((bitcast(f32(h)) << 1u) - 1u) < 0xff000000)) { + res = 3i; + } + return res; + }" + `); + }); + + it('non-primitive values', ({ root }) => { + const buffer = root.createUniform(d.mat4x4f); + const testFn = tgpu.fn( + [d.vec3f, d.atomic(d.u32), d.ptrPrivate(d.u32)], + d.i32, + )((v, a, p) => { + let res = -1; + if (buffer) { + res = 0; + } + if (buffer.$) { + res = 1; + } + if (v) { + res = 2; + } + if (p) { + res = 5; + } + if (p.$) { + res = 6; + } + return res; + }); + + expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(` + "fn testFn(v: vec3f, a: atomic, p: ptr) -> i32 { + var res = -1; + { + res = 0i; + } + { + res = 1i; + } + { + res = 2i; + } + { + res = 5i; + } + if (bool((*p))) { + res = 6i; + } + return res; + }" + `); + }); + + it('atomic', () => { + const testFn = tgpu.fn( + [d.atomic(d.u32)], + d.i32, + )((a) => { + let res = -1; + if (a) { + res = 3; + } + if (std.atomicLoad(a)) { + res = 4; + } + return res; + }); + + expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(` + "fn testFn(a: atomic) -> i32 { + var res = -1; + { + res = 3i; + } + if (bool(atomicLoad(&a))) { + res = 4i; + } + return res; + }" + `); + }); + + it('primitive comptime-known operands', () => { + const getTruthy = tgpu.comptime(() => 1882); + const getFalsy = tgpu.comptime(() => 0); + + const f = () => { + 'use gpu'; + let res = -1; + if (getTruthy()) { + res = 1; + } + if (getFalsy()) { + res = 2; + } + return res; + }; + + expect(tgpu.resolve([f])).toMatchInlineSnapshot(` + "fn f() -> i32 { + var res = -1; + { + res = 1i; + } + return res; + }" + `); + }); + + it('operands from slots and accessors', () => { + const Boid = d.struct({ + pos: d.vec2f, + vel: d.vec2f, + }); + + const slot = tgpu.slot>({ pos: d.vec2f(), vel: d.vec2f() }); + const accessor = tgpu.accessor(d.vec4u, d.vec4u(1, 8, 8, 2)); + + const f = () => { + 'use gpu'; + let res = -1; + if (slot.$) { + res = 1; + } + if (accessor.$) { + res = 2; + } + return res; + }; + + expect(tgpu.resolve([f])).toMatchInlineSnapshot(` + "fn f() -> i32 { + var res = -1; + { + res = 1i; + } + { + res = 2i; + } + return res; + }" + `); + }); + + it('complex comptime-known operand', () => { + const slotEmpty = tgpu.slot<{ a?: number }>({}); + const slotFull = tgpu.slot<{ a?: number }>({ a: 42 }); + + const f = () => { + 'use gpu'; + let res = -1; + if (slotEmpty.$.a) { + res = 1; + } + if (slotFull.$.a) { + res = 2; + } + return res; + }; + + expect(tgpu.resolve([f])).toMatchInlineSnapshot(` + "fn f() -> i32 { + var res = -1; + { + res = 2i; + } + return res; + }" + `); + }); + + it('operand of && and ||', () => { + const testFn = tgpu.fn( + [d.i32, d.bool], + d.i32, + )((n, b) => { + let res = 0; + if (b && n) { + res = 1; + } + if (n || b) { + res = 2; + } + + return res; + }); + + expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(` + "fn testFn(n: i32, b: bool) -> i32 { + var res = 0; + if ((b && bool(n))) { + res = 1i; + } + if ((bool(n) || b)) { + res = 2i; + } + return res; + }" + `); + }); + }); + describe('short-circuit evaluation', () => { const state = { counter: 0,