Skip to content

Commit ace8591

Browse files
authored
fix/feat: bit operation on vectors (#2276)
1 parent 6a8b3ac commit ace8591

9 files changed

Lines changed: 475 additions & 21 deletions

File tree

packages/typegpu/src/data/index.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ assignInfixOperator(MatBase, 'add', Operator.plus);
3636
assignInfixOperator(MatBase, 'sub', Operator.minus);
3737
assignInfixOperator(MatBase, 'mul', Operator.star);
3838

39+
// bitShift does not yet have tsover operator symbol
40+
{
41+
// oxlint-disable-next-line typescript/no-explicit-any -- anything is possible
42+
const proto = VecBase.prototype as any;
43+
proto.bitShiftLeft = function (this: unknown, other: unknown) {
44+
return (infixOperators.bitShiftLeft as (a: unknown, b: unknown) => unknown)(this, other);
45+
};
46+
proto.bitShiftRight = function (this: unknown, other: unknown) {
47+
return (infixOperators.bitShiftRight as (a: unknown, b: unknown) => unknown)(this, other);
48+
};
49+
}
50+
3951
export { bool, f16, f32, i32, u16, u32 } from './numeric.ts';
4052
export {
4153
isAlignAttrib,

packages/typegpu/src/data/vectorOps.ts

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,15 @@ const binaryComponentWise4i = (op: BinaryOp) => (a: wgsl.v4i, b: wgsl.v4i) =>
117117
const binaryComponentWise4u = (op: BinaryOp) => (a: wgsl.v4u, b: wgsl.v4u) =>
118118
vec4u(op(a.x, b.x), op(a.y, b.y), op(a.z, b.z), op(a.w, b.w));
119119

120+
const binaryComponentWise2i2u = (op: BinaryOp) => (a: wgsl.v2i, b: wgsl.v2u) =>
121+
vec2i(op(a.x, b.x), op(a.y, b.y));
122+
123+
const binaryComponentWise3i3u = (op: BinaryOp) => (a: wgsl.v3i, b: wgsl.v3u) =>
124+
vec3i(op(a.x, b.x), op(a.y, b.y), op(a.z, b.z));
125+
126+
const binaryComponentWise4i4u = (op: BinaryOp) => (a: wgsl.v4i, b: wgsl.v4u) =>
127+
vec4i(op(a.x, b.x), op(a.y, b.y), op(a.z, b.z), op(a.w, b.w));
128+
120129
const binaryComponentWise2x2f = (op: BinaryOp) => (a: wgsl.m2x2f, b: wgsl.m2x2f) => {
121130
const a_ = a.columns as [wgsl.v2f, wgsl.v2f];
122131
const b_ = b.columns as [wgsl.v2f, wgsl.v2f];
@@ -1090,6 +1099,34 @@ export const VectorOps = {
10901099
vec4h: unary4h(Math.tanh),
10911100
} as Record<VecKind, <T extends vBase>(v: T) => T>,
10921101

1102+
bitShiftLeft: {
1103+
vec2i: binaryComponentWise2i2u((a, b) => a << b),
1104+
vec2u: binaryComponentWise2u((a, b) => a << b),
1105+
1106+
vec3i: binaryComponentWise3i3u((a, b) => a << b),
1107+
vec3u: binaryComponentWise3u((a, b) => a << b),
1108+
1109+
vec4i: binaryComponentWise4i4u((a, b) => a << b),
1110+
vec4u: binaryComponentWise4u((a, b) => a << b),
1111+
} as Record<
1112+
VecKind,
1113+
<T extends wgsl.AnyIntegerVecInstance, U extends wgsl.AnyUnsignedVecInstance>(a: T, b: U) => T
1114+
>,
1115+
1116+
bitShiftRight: {
1117+
vec2i: binaryComponentWise2i2u((a, b) => a >> b),
1118+
vec2u: binaryComponentWise2u((a, b) => a >> b),
1119+
1120+
vec3i: binaryComponentWise3i3u((a, b) => a >> b),
1121+
vec3u: binaryComponentWise3u((a, b) => a >> b),
1122+
1123+
vec4i: binaryComponentWise4i4u((a, b) => a >> b),
1124+
vec4u: binaryComponentWise4u((a, b) => a >> b),
1125+
} as Record<
1126+
VecKind,
1127+
<T extends wgsl.AnyIntegerVecInstance, U extends wgsl.AnyUnsignedVecInstance>(a: T, b: U) => T
1128+
>,
1129+
10931130
bitcastU32toF32: {
10941131
vec2u: (n: wgsl.v2u) => vec2f(bitcastU32toF32Impl(n.x), bitcastU32toF32Impl(n.y)),
10951132
vec3u: (n: wgsl.v3u) =>

packages/typegpu/src/data/wgslTypes.ts

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,17 @@ export interface vecInfixNotation<T extends vecBase> {
7070
[Symbol.operatorPercent](lhs: T | number, rhs: T | number): T;
7171
}
7272

73+
export type vecIToVecU<T extends AnyIntegerVecInstance> = T extends v2i | v2u
74+
? v2u
75+
: T extends v3i | v3u
76+
? v3u
77+
: v4u;
78+
79+
export interface vecBitShiftNotation<T extends AnyIntegerVecInstance> {
80+
bitShiftLeft(rhs: vecIToVecU<T> | number): T;
81+
bitShiftRight(rhs: vecIToVecU<T> | number): T;
82+
}
83+
7384
/**
7485
* Matrix infix notation.
7586
*
@@ -228,7 +239,8 @@ export interface v2h extends Tuple2<number>, Swizzle2<v2h, v3h, v4h>, vecInfixNo
228239
* Interface representing its WGSL vector type counterpart: vec2i or vec2<i32>.
229240
* A vector with 2 elements of type i32
230241
*/
231-
export interface v2i extends Tuple2<number>, Swizzle2<v2i, v3i, v4i>, vecInfixNotation<v2i> {
242+
export interface v2i
243+
extends Tuple2<number>, Swizzle2<v2i, v3i, v4i>, vecInfixNotation<v2i>, vecBitShiftNotation<v2i> {
232244
readonly [$internal]: true;
233245
/** use to distinguish between vectors of the same size on the type level */
234246
readonly kind: 'vec2i';
@@ -242,7 +254,8 @@ export interface v2i extends Tuple2<number>, Swizzle2<v2i, v3i, v4i>, vecInfixNo
242254
* Interface representing its WGSL vector type counterpart: vec2u or vec2<u32>.
243255
* A vector with 2 elements of type u32
244256
*/
245-
export interface v2u extends Tuple2<number>, Swizzle2<v2u, v3u, v4u>, vecInfixNotation<v2u> {
257+
export interface v2u
258+
extends Tuple2<number>, Swizzle2<v2u, v3u, v4u>, vecInfixNotation<v2u>, vecBitShiftNotation<v2u> {
246259
readonly [$internal]: true;
247260
/** use to distinguish between vectors of the same size on the type level */
248261
readonly kind: 'vec2u';
@@ -302,7 +315,8 @@ export interface v3h extends Tuple3<number>, Swizzle3<v2h, v3h, v4h>, vecInfixNo
302315
* Interface representing its WGSL vector type counterpart: vec3i or vec3<i32>.
303316
* A vector with 3 elements of type i32
304317
*/
305-
export interface v3i extends Tuple3<number>, Swizzle3<v2i, v3i, v4i>, vecInfixNotation<v3i> {
318+
export interface v3i
319+
extends Tuple3<number>, Swizzle3<v2i, v3i, v4i>, vecInfixNotation<v3i>, vecBitShiftNotation<v3i> {
306320
readonly [$internal]: true;
307321
/** use to distinguish between vectors of the same size on the type level */
308322
readonly kind: 'vec3i';
@@ -318,7 +332,8 @@ export interface v3i extends Tuple3<number>, Swizzle3<v2i, v3i, v4i>, vecInfixNo
318332
* Interface representing its WGSL vector type counterpart: vec3u or vec3<u32>.
319333
* A vector with 3 elements of type u32
320334
*/
321-
export interface v3u extends Tuple3<number>, Swizzle3<v2u, v3u, v4u>, vecInfixNotation<v3u> {
335+
export interface v3u
336+
extends Tuple3<number>, Swizzle3<v2u, v3u, v4u>, vecInfixNotation<v3u>, vecBitShiftNotation<v3u> {
322337
readonly [$internal]: true;
323338
/** use to distinguish between vectors of the same size on the type level */
324339
readonly kind: 'vec3u';
@@ -386,7 +401,8 @@ export interface v4h extends Tuple4<number>, Swizzle4<v2h, v3h, v4h>, vecInfixNo
386401
* Interface representing its WGSL vector type counterpart: vec4i or vec4<i32>.
387402
* A vector with 4 elements of type i32
388403
*/
389-
export interface v4i extends Tuple4<number>, Swizzle4<v2i, v3i, v4i>, vecInfixNotation<v4i> {
404+
export interface v4i
405+
extends Tuple4<number>, Swizzle4<v2i, v3i, v4i>, vecInfixNotation<v4i>, vecBitShiftNotation<v4i> {
390406
readonly [$internal]: true;
391407
/** use to distinguish between vectors of the same size on the type level */
392408
readonly kind: 'vec4i';
@@ -404,7 +420,8 @@ export interface v4i extends Tuple4<number>, Swizzle4<v2i, v3i, v4i>, vecInfixNo
404420
* Interface representing its WGSL vector type counterpart: vec4u or vec4<u32>.
405421
* A vector with 4 elements of type u32
406422
*/
407-
export interface v4u extends Tuple4<number>, Swizzle4<v2u, v3u, v4u>, vecInfixNotation<v4u> {
423+
export interface v4u
424+
extends Tuple4<number>, Swizzle4<v2u, v3u, v4u>, vecInfixNotation<v4u>, vecBitShiftNotation<v4u> {
408425
readonly [$internal]: true;
409426
/** use to distinguish between vectors of the same size on the type level */
410427
readonly kind: 'vec4u';
@@ -1528,12 +1545,18 @@ export function isMat(value: unknown): value is Mat2x2f | Mat3x3f | Mat4x4f {
15281545
return isMat2x2f(value) || isMat3x3f(value) || isMat4x4f(value);
15291546
}
15301547

1531-
export function isFloat32VecInstance(
1532-
element: number | AnyVecInstance | AnyMatInstance,
1533-
): element is AnyFloat32VecInstance {
1548+
export function isFloat32VecInstance(element: unknown): element is AnyFloat32VecInstance {
15341549
return isVecInstance(element) && ['vec2f', 'vec3f', 'vec4f'].includes(element.kind);
15351550
}
15361551

1552+
export function isInteger32VecInstance(value: unknown): value is v2u | v2i | v3u | v3i | v4u | v4i {
1553+
return isVecInstance(value) && /[iu]$/.test(value.kind);
1554+
}
1555+
1556+
export function isUint32VecInstance(value: unknown): value is v2u | v3u | v4u {
1557+
return isVecInstance(value) && /[u]$/.test(value.kind);
1558+
}
1559+
15371560
export function isWgslData(value: unknown): value is AnyWgslData {
15381561
return isMarkedInternal(value) && wgslTypeLiterals.includes((value as AnyWgslData)?.type);
15391562
}

packages/typegpu/src/std/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ export {
7070
trunc,
7171
} from './numeric.ts';
7272

73-
export { add, div, mod, mul, neg, sub } from './operators.ts';
73+
export { add, bitShiftLeft, bitShiftRight, div, mod, mul, neg, sub } from './operators.ts';
7474

7575
export { rotateX4, rotateY4, rotateZ4, scale4, translate4 } from './matrix.ts';
7676

packages/typegpu/src/std/operators.ts

Lines changed: 116 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
import { dualImpl } from '../core/function/dualImpl.ts';
22
import { stitch } from '../core/resolve/stitch.ts';
3-
import { abstractFloat, f16, f32 } from '../data/numeric.ts';
4-
import { vecTypeToConstructor } from '../data/vector.ts';
3+
import { abstractFloat, f16, f32, i32, u32 } from '../data/numeric.ts';
4+
import { vec2i, vec2u, vec3i, vec3u, vec4i, vec4u, vecTypeToConstructor } from '../data/vector.ts';
55
import { VectorOps } from '../data/vectorOps.ts';
66
import {
7+
type AnyIntegerVecInstance,
78
type AnyMatInstance,
89
type AnyNumericVecInstance,
910
type BaseData,
11+
type mBaseForVec,
12+
type vBaseForMat,
1013
isFloat32VecInstance,
14+
isInteger32VecInstance,
1115
isMat,
1216
isMatInstance,
17+
isUint32VecInstance,
1318
isVec,
1419
isVecInstance,
15-
type mBaseForVec,
16-
type vBaseForMat,
20+
vecIToVecU,
1721
} from '../data/wgslTypes.ts';
1822
import { SignatureNotSupportedError } from '../errors.ts';
1923
import { unify } from '../tgsl/conversion.ts';
@@ -279,3 +283,111 @@ export const neg = dualImpl({
279283
normalImpl: cpuNeg,
280284
codegenImpl: (_ctx, [arg]) => stitch`-(${arg})`,
281285
});
286+
287+
const anyConcreteInteger = [i32, u32, vec2i, vec3i, vec4i, vec2u, vec3u, vec4u] as BaseData[];
288+
289+
const intVecToUnsignedVec = {
290+
vec2i: vec2u,
291+
vec2u: vec2u,
292+
vec3i: vec3u,
293+
vec3u: vec3u,
294+
vec4i: vec4u,
295+
vec4u: vec4u,
296+
} as const;
297+
298+
const bitShiftSignature = (lhs: BaseData, rhs: BaseData) => {
299+
const lhsUnified = unify([lhs], anyConcreteInteger)?.[0];
300+
if (!lhsUnified) {
301+
throw new SignatureNotSupportedError([lhs], anyConcreteInteger);
302+
}
303+
304+
let rhsType: BaseData;
305+
if (isVec(lhsUnified)) {
306+
const cc = lhsUnified.componentCount;
307+
const vecU = cc === 2 ? vec2u : cc === 3 ? vec3u : vec4u;
308+
const rhsUnified = unify([rhs], [u32, vecU])?.[0];
309+
if (!rhsUnified) {
310+
throw new SignatureNotSupportedError([rhs], [u32, vecU]);
311+
}
312+
rhsType = rhsUnified;
313+
} else {
314+
rhsType = u32;
315+
}
316+
317+
return {
318+
argTypes: [lhsUnified, rhsType],
319+
returnType: lhsUnified,
320+
};
321+
};
322+
323+
function cpuBitShiftLeft(lhs: number, rhs: number): number;
324+
function cpuBitShiftLeft<T extends AnyIntegerVecInstance>(lhs: T, rhs: number): T;
325+
function cpuBitShiftLeft<T extends AnyIntegerVecInstance>(lhs: T, rhs: vecIToVecU<T>): T;
326+
function cpuBitShiftLeft<T extends AnyIntegerVecInstance>(
327+
lhs: number | AnyIntegerVecInstance,
328+
rhs: number | vecIToVecU<T>,
329+
) {
330+
if (typeof lhs === 'number' && typeof rhs === 'number') {
331+
return lhs << rhs;
332+
}
333+
if (isInteger32VecInstance(lhs) && isUint32VecInstance(rhs) && lhs.length == rhs.length) {
334+
return VectorOps.bitShiftLeft[lhs.kind](lhs, rhs);
335+
}
336+
if (isInteger32VecInstance(lhs) && typeof rhs === 'number') {
337+
const rhsVec = intVecToUnsignedVec[lhs.kind](rhs);
338+
return VectorOps.bitShiftLeft[lhs.kind](lhs, rhsVec);
339+
}
340+
throw new Error(
341+
'bitShiftLeft called with invalid arguments, expected types: number or integer vector (rhs must be the same arity as lhs).',
342+
);
343+
}
344+
345+
export const bitShiftLeft = dualImpl({
346+
name: 'bitShiftLeft',
347+
signature: bitShiftSignature,
348+
normalImpl: cpuBitShiftLeft,
349+
codegenImpl: (_ctx, [lhs, rhs]) => {
350+
if (isVec(lhs.dataType) && !isVec(rhs.dataType)) {
351+
const cc = lhs.dataType.componentCount;
352+
const schema = cc === 2 ? 'vec2u' : cc === 3 ? 'vec3u' : 'vec4u';
353+
return stitch`(${lhs} << ${schema}(${rhs}))`;
354+
}
355+
return stitch`(${lhs} << ${rhs})`;
356+
},
357+
});
358+
359+
function cpuBitShiftRight(lhs: number, rhs: number): number;
360+
function cpuBitShiftRight<T extends AnyIntegerVecInstance>(lhs: T, rhs: number): T;
361+
function cpuBitShiftRight<T extends AnyIntegerVecInstance>(lhs: T, rhs: vecIToVecU<T>): T;
362+
function cpuBitShiftRight<T extends AnyIntegerVecInstance>(
363+
lhs: number | AnyIntegerVecInstance,
364+
rhs: number | vecIToVecU<T>,
365+
) {
366+
if (typeof lhs === 'number' && typeof rhs === 'number') {
367+
return lhs >> rhs;
368+
}
369+
if (isInteger32VecInstance(lhs) && isUint32VecInstance(rhs) && lhs.length == rhs.length) {
370+
return VectorOps.bitShiftRight[lhs.kind](lhs, rhs);
371+
}
372+
if (isInteger32VecInstance(lhs) && typeof rhs === 'number') {
373+
const rhsVec = intVecToUnsignedVec[lhs.kind](rhs);
374+
return VectorOps.bitShiftRight[lhs.kind](lhs, rhsVec);
375+
}
376+
throw new Error(
377+
'bitShiftRight called with invalid arguments, expected types: number or integer vector (rhs must be the same arity as lhs).',
378+
);
379+
}
380+
381+
export const bitShiftRight = dualImpl({
382+
name: 'bitShiftRight',
383+
signature: bitShiftSignature,
384+
normalImpl: cpuBitShiftRight,
385+
codegenImpl: (_ctx, [lhs, rhs]) => {
386+
if (isVec(lhs.dataType) && !isVec(rhs.dataType)) {
387+
const cc = lhs.dataType.componentCount;
388+
const schema = cc === 2 ? 'vec2u' : cc === 3 ? 'vec3u' : 'vec4u';
389+
return stitch`(${lhs} >> ${schema}(${rhs}))`;
390+
}
391+
return stitch`(${lhs} >> ${rhs})`;
392+
},
393+
});

packages/typegpu/src/tgsl/accessProp.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ import {
3737
isWgslStruct,
3838
} from '../data/wgslTypes.ts';
3939
import { $gpuCallable } from '../shared/symbols.ts';
40-
import { add, div, mod, mul, sub } from '../std/operators.ts';
40+
import { add, bitShiftLeft, bitShiftRight, div, mod, mul, sub } from '../std/operators.ts';
4141
import { isKnownAtComptime } from '../types.ts';
4242
import { coerceToSnippet } from './generationHelpers.ts';
4343

@@ -65,6 +65,8 @@ export const infixOperators = {
6565
mul,
6666
div,
6767
mod,
68+
bitShiftLeft,
69+
bitShiftRight,
6870
} as const;
6971

7072
export type InfixOperator = keyof typeof infixOperators;

packages/typegpu/src/tgsl/wgslGenerator.ts

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import {
1010
unptr,
1111
} from '../data/dataTypes.ts';
1212
import { bool, i32, u32 } from '../data/numeric.ts';
13+
import { vec2u, vec3u, vec4u } from '../data/vector.ts';
1314
import {
1415
fallthroughCopyOrigin,
1516
isEphemeralOrigin,
@@ -77,6 +78,8 @@ const parenthesizedOps = [
7778

7879
const binaryLogicalOps = ['&&', '||', '==', '!=', '===', '!==', '<', '<=', '>', '>='];
7980

81+
const bitShiftOps: string[] = ['<<', '>>', '<<=', '>>='];
82+
8083
const OP_MAP = {
8184
//
8285
// binary
@@ -348,12 +351,28 @@ ${this.ctx.pre}}`;
348351
return codegen(this.ctx, [lhsExpr, rhsExpr]);
349352
}
350353

351-
const forcedType = exprType === NODE.assignmentExpr ? [lhsExpr.dataType] : undefined;
354+
let convLhs: Snippet;
355+
let convRhs: Snippet;
352356

353-
const [convLhs, convRhs] = convertToCommonType(this.ctx, [lhsExpr, rhsExpr], forcedType) ?? [
354-
lhsExpr,
355-
rhsExpr,
356-
];
357+
if (bitShiftOps.includes(op)) {
358+
// rhs must be u32 (or vecN<u32> for vector lhs)
359+
let rhsTarget: wgsl.BaseData;
360+
if (wgsl.isVec(lhsExpr.dataType)) {
361+
const cc = lhsExpr.dataType.componentCount;
362+
rhsTarget = cc === 2 ? vec2u : cc === 3 ? vec3u : vec4u;
363+
} else {
364+
rhsTarget = u32;
365+
}
366+
convRhs = tryConvertSnippet(this.ctx, rhsExpr, rhsTarget, false);
367+
// if lhs is not an integer type, the browser will return a descriptive wgsl error
368+
convLhs = lhsExpr;
369+
} else {
370+
const forcedType = exprType === NODE.assignmentExpr ? [lhsExpr.dataType] : undefined;
371+
[convLhs, convRhs] = convertToCommonType(this.ctx, [lhsExpr, rhsExpr], forcedType) ?? [
372+
lhsExpr,
373+
rhsExpr,
374+
];
375+
}
357376

358377
const lhsStr = this.ctx.resolve(convLhs.value, convLhs.dataType).value;
359378
const rhsStr = this.ctx.resolve(convRhs.value, convRhs.dataType).value;

packages/typegpu/tests/examples/individual/fluid-with-atomics.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ describe('fluid with atomics example', () => {
103103
}
104104
if (isWaterDrain(x, y)) {
105105
persistFlags(x, y);
106-
updateCell(x, y, (3 << 24));
106+
updateCell(x, y, (3 << 24u));
107107
return true;
108108
}
109109
if (((((y == 0u) || (y == (size.y - 1u))) || (x == 0u)) || (x == (size.x - 1u)))) {

0 commit comments

Comments
 (0)