Skip to content
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions packages/typegpu/src/core/function/createCallableSchema.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type MapValueToSnippet, type ResolvedSnippet, snip } from '../../data/snippet.ts';
import { type MapValueToSnippet, snip, type Snippet } from '../../data/snippet.ts';
import { type BaseData, isPtr } from '../../data/wgslTypes.ts';
import { setName } from '../../shared/meta.ts';
import { $gpuCallable } from '../../shared/symbols.ts';
Expand All @@ -12,10 +12,7 @@ interface CallableSchemaOptions<T extends AnyFn> {
readonly name: string;
readonly schema: () => BaseData;
readonly normalImpl: T;
readonly codegenImpl: (
ctx: ResolutionCtx,
args: MapValueToSnippet<Parameters<T>>,
) => ResolvedSnippet;
readonly codegenImpl: (ctx: ResolutionCtx, args: MapValueToSnippet<Parameters<T>>) => Snippet;
readonly argTypes: (
...inArgTypes: MapValueToDataType<Parameters<T>>
) => (BaseData | BaseData[])[];
Expand Down
34 changes: 27 additions & 7 deletions packages/typegpu/src/data/numeric.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { $internal } from '../shared/symbols.ts';
import { isBool, isNumericSchema } from './wgslTypes.ts';
import type { AbstractFloat, AbstractInt, Bool, F16, F32, I32, U16, U32 } from './wgslTypes.ts';
import { callableSchema } from '../core/function/createCallableSchema.ts';
import { snip } from './snippet.ts';

export const abstractInt = {
[$internal]: {},
Expand All @@ -22,16 +24,34 @@ const boolCast = callableSchema({
name: 'bool',
schema: () => bool,
argTypes: (arg) => (arg ? [arg] : []),
normalImpl(v?: number | boolean) {
if (v === undefined) {
return false;
normalImpl(v?: unknown) {
return !!v;
},
codegenImpl: (ctx, [arg]) => {
if (arg === undefined) {
return snip(false, bool, 'constant');
}
if (typeof v === 'boolean') {
return v;

const { dataType } = arg;

if (isBool(dataType)) {
return ctx.gen.typeInstantiation(bool, [arg]);
}
return !!v;

if (isNumericSchema(dataType)) {
const argStr = ctx.resolveSnippet(arg).value;
const resultStr = `bool(${argStr})`;

const nanGuardedStr =
dataType.type === 'f32' || dataType.type === 'f16'
? `(((bitcast<u32>(${dataType.type === 'f16' ? `f32(${argStr})` : argStr}) & 0x7fffffff) <= 0x7f800000) && ${resultStr})`
: resultStr;

return snip(nanGuardedStr, bool, 'runtime');
}

Comment thread
cieplypolar marked this conversation as resolved.
return snip(true, bool, 'constant');
},
codegenImpl: (ctx, args) => ctx.gen.typeInstantiation(bool, args),
});

/**
Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/src/data/wgslTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ export type mBaseForVec<T extends vecBase> = T extends v2f
* Boolean schema representing a single WGSL bool value.
* Cannot be used inside buffers as it is not host-shareable.
*/
export interface Bool extends BaseData, DualFn<(v?: number | boolean) => boolean> {
export interface Bool extends BaseData, DualFn<(v?: unknown) => boolean> {
readonly type: 'bool';

// Type-tokens, not available at runtime
Expand Down
74 changes: 42 additions & 32 deletions packages/typegpu/src/tgsl/wgslGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -162,29 +162,13 @@ const unaryOpCodeToCodegen = {
throw new Error('The unary operator `!` expects 1 argument, but 0 were provided.');
}

if (isKnownAtComptime(argExpr)) {
return snip(!argExpr.value, bool, 'constant');
}

const { value, dataType } = argExpr;
const argStr = ctx.resolve(value, dataType).value;
const convArg = bool[$gpuCallable].call(ctx, [argExpr]);

if (wgsl.isBool(dataType)) {
return snip(`!${argStr}`, bool, 'runtime');
}
if (wgsl.isNumericSchema(dataType)) {
const resultStr = `!bool(${argStr})`;
const nanGuardedStr = // abstractFloat will be resolved as comptime known value
dataType.type === 'f32'
? `(((bitcast<u32>(${argStr}) & 0x7fffffff) > 0x7f800000) || ${resultStr})`
: dataType.type === 'f16'
? `(((bitcast<u32>(${argStr}) & 0x7fff) > 0x7c00) || ${resultStr})`
: resultStr;

return snip(nanGuardedStr, bool, 'runtime');
if (isKnownAtComptime(convArg)) {
return snip(!convArg.value, bool, 'constant');
}

return snip(false, bool, 'constant');
return snip(`!${convArg.value}`, bool, 'runtime');
},
} satisfies Partial<Record<tinyest.UnaryOperator, (...args: never[]) => unknown>>;

Expand Down Expand Up @@ -329,6 +313,11 @@ ${this.ctx.pre}}`;
// convert the result.
return result;
}

if (wgsl.isBool(expectedType)) {
return bool[$gpuCallable].call(this.ctx, [result]);
}

return tryConvertSnippet(this.ctx, result, expectedType);
} finally {
this.ctx.expectedType = prevExpectedType;
Expand All @@ -344,17 +333,12 @@ ${this.ctx.pre}}`;
return snip(expression, bool, /* origin */ 'constant');
}

if (
expression[0] === NODE.logicalExpr ||
expression[0] === NODE.binaryExpr ||
expression[0] === NODE.assignmentExpr
) {
// Logical/Binary/Assignment Expression
const [exprType, lhs, op, rhs] = expression;
if (expression[0] === NODE.logicalExpr) {
const [_, lhs, op, rhs] = expression;
const lhsExpr = this._expression(lhs);

// Short Circuit Evaluation
if ((op === '||' || op === '&&') && isKnownAtComptime(lhsExpr)) {
if (isKnownAtComptime(lhsExpr)) {
const evalRhs = op === '&&' ? lhsExpr.value : !lhsExpr.value;

if (!evalRhs) {
Expand All @@ -363,14 +347,14 @@ ${this.ctx.pre}}`;

const rhsExpr = this._expression(rhs);

if (rhsExpr.dataType === UnknownData) {
throw new WgslTypeError(`Right-hand side of '${op}' is of unknown type`);
}

if (isKnownAtComptime(rhsExpr)) {
return snip(!!rhsExpr.value, bool, 'constant');
}

if (rhsExpr.dataType === UnknownData) {
throw new WgslTypeError(`Right-hand side of '${op}' is of unknown type`);
}

// we can skip lhs
const convRhs = tryConvertSnippet(this.ctx, rhsExpr, bool, false);
const rhsStr = this.ctx.resolve(convRhs.value, convRhs.dataType).value;
Expand All @@ -379,6 +363,32 @@ ${this.ctx.pre}}`;

const rhsExpr = this._expression(rhs);

// they are not know at comptime
Comment thread
cieplypolar marked this conversation as resolved.
Outdated
if (lhsExpr.dataType === UnknownData) {
throw new WgslTypeError(`Left-hand side of '${op}' is of unknown type`);
}

if (!isKnownAtComptime(rhsExpr) && rhsExpr.dataType === UnknownData) {
throw new WgslTypeError(`Right-hand side of '${op}' is of unknown type`);
}

const [convLhs, convRhs] = convertToCommonType(this.ctx, [lhsExpr, rhsExpr], [bool]) ?? [
lhsExpr,
rhsExpr,
];

const lhsStr = this.ctx.resolve(convLhs.value, convLhs.dataType).value;
const rhsStr = this.ctx.resolve(convRhs.value, convRhs.dataType).value;

return snip(`(${lhsStr} ${op} ${rhsStr})`, bool, /* origin */ 'runtime');
}

if (expression[0] === NODE.binaryExpr || expression[0] === NODE.assignmentExpr) {
// Binary/Assignment Expression
const [exprType, lhs, op, rhs] = expression;
const lhsExpr = this._expression(lhs);
const rhsExpr = this._expression(rhs);

if (rhsExpr.value instanceof RefOperator) {
throw new WgslTypeError(
stitch`Cannot assign a ref to an existing variable '${lhsExpr}', define a new variable instead.`,
Expand Down
39 changes: 24 additions & 15 deletions packages/typegpu/tests/tgsl/typeInference.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ describe('wgsl generator type inference', () => {
`);
});

it('throws when if condition is not boolean', () => {
it('converts if condition to boolean', () => {
const myFn = tgpu.fn(
[],
d.bool,
Expand All @@ -282,14 +282,17 @@ describe('wgsl generator type inference', () => {
return false;
});

expect(() => tgpu.resolve([myFn])).toThrowErrorMatchingInlineSnapshot(`
[Error: Resolution of the following tree failed:
- <root>
- fn:myFn: Cannot convert value of type 'vec2<bool>' to any of the target types: [bool]]
expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(`
"fn myFn() -> bool {
{
return true;
}
return false;
}"
`);
});

it('throws when while condition is not boolean', () => {
it('converts while condition to boolean', () => {
const myFn = tgpu.fn(
[],
d.bool,
Expand All @@ -300,14 +303,17 @@ describe('wgsl generator type inference', () => {
return false;
});

expect(() => tgpu.resolve([myFn])).toThrowErrorMatchingInlineSnapshot(`
[Error: Resolution of the following tree failed:
- <root>
- fn:myFn: Cannot convert value of type 'mat2x2f' to any of the target types: [bool]]
expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(`
"fn myFn() -> bool {
while (true) {
return true;
}
return false;
}"
`);
});

it('throws when for condition is not boolean', () => {
it('converts for condition to boolean', () => {
const myFn = tgpu.fn(
[],
d.bool,
Expand All @@ -318,10 +324,13 @@ describe('wgsl generator type inference', () => {
return false;
});

expect(() => tgpu.resolve([myFn])).toThrowErrorMatchingInlineSnapshot(`
[Error: Resolution of the following tree failed:
- <root>
- fn:myFn: Cannot convert value of type 'abstractInt' to any of the target types: [bool]]
expect(tgpu.resolve([myFn])).toMatchInlineSnapshot(`
"fn myFn() -> bool {
for (var i = 0; true; (i < 10i)) {
return true;
}
return false;
}"
`);
});

Expand Down
Loading
Loading