Skip to content

Commit d5d8801

Browse files
authored
feat: Overload operators for vectors (#2176)
1 parent 7de6653 commit d5d8801

30 files changed

Lines changed: 1286 additions & 471 deletions

apps/typegpu-docs/.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,7 @@ src/content/docs/api
2626
# tests
2727
tests/artifacts
2828
!tests/artifacts/README.md
29+
30+
# generated transformed files
31+
*.tsnotover.ts
32+
*.tsnotover.tsx

apps/typegpu-docs/package.json

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
"version": "0.0.1",
55
"private": true,
66
"scripts": {
7-
"dev": "astro dev",
8-
"build": "astro check && astro build",
7+
"transform-overloads": "find . -type f -name '*.tsnotover.ts' -delete && node scripts/transform-overloads.ts",
8+
"dev": "pnpm run transform-overloads && astro dev",
9+
"build": "pnpm run transform-overloads && astro check && astro build",
910
"test:types": "astro check",
1011
"preview": "astro preview",
1112
"astro": "astro"
@@ -78,6 +79,7 @@
7879
"@webgpu/types": "catalog:types",
7980
"astro-vtbot": "^2.1.10",
8081
"autoprefixer": "^10.4.21",
82+
"magic-string": "^0.30.21",
8183
"tailwindcss": "^4.1.11",
8284
"tailwindcss-motion": "^1.1.1",
8385
"vite-imagetools": "catalog:frontend",
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
import ts from 'typescript';
2+
import MagicString from 'magic-string';
3+
import { readdir } from 'fs/promises';
4+
import { basename, dirname, extname, join, relative } from 'path';
5+
import { fileURLToPath } from 'url';
6+
import { writeFile } from 'fs/promises';
7+
8+
const __dirname = dirname(fileURLToPath(import.meta.url));
9+
const projectRoot = join(__dirname, '..');
10+
const examplesDir = join(projectRoot, 'src', 'examples');
11+
12+
const operatorToMethod: Record<string, string> = {
13+
[ts.SyntaxKind.PlusToken]: 'add',
14+
[ts.SyntaxKind.PlusEqualsToken]: 'add',
15+
[ts.SyntaxKind.MinusToken]: 'sub',
16+
[ts.SyntaxKind.MinusEqualsToken]: 'sub',
17+
[ts.SyntaxKind.AsteriskToken]: 'mul',
18+
[ts.SyntaxKind.AsteriskEqualsToken]: 'mul',
19+
[ts.SyntaxKind.SlashToken]: 'div',
20+
[ts.SyntaxKind.SlashEqualsToken]: 'div',
21+
[ts.SyntaxKind.AsteriskAsteriskToken]: 'pow',
22+
[ts.SyntaxKind.AsteriskAsteriskEqualsToken]: 'pow',
23+
};
24+
25+
const assignmentOperators = [
26+
ts.SyntaxKind.PlusEqualsToken,
27+
ts.SyntaxKind.MinusEqualsToken,
28+
ts.SyntaxKind.AsteriskEqualsToken,
29+
ts.SyntaxKind.SlashEqualsToken,
30+
];
31+
32+
const commutativeMethods = ['add', 'mul'];
33+
34+
async function findTypeScriptFiles(dir: string): Promise<string[]> {
35+
const files: string[] = [];
36+
37+
async function walk(currentDir: string): Promise<void> {
38+
const entries = await readdir(currentDir, { withFileTypes: true });
39+
40+
for (const entry of entries) {
41+
const fullPath = join(currentDir, entry.name);
42+
43+
if (entry.isDirectory()) {
44+
await walk(fullPath);
45+
} else if (entry.isFile()) {
46+
const ext = extname(entry.name);
47+
if (
48+
(ext === '.ts' || ext === '.tsx') &&
49+
!entry.name.endsWith('.d.ts') &&
50+
!entry.name.endsWith('.d.tsx') &&
51+
!entry.name.endsWith('.tsnotover.ts') &&
52+
!entry.name.endsWith('.tsnotover.tsx')
53+
) {
54+
files.push(fullPath);
55+
}
56+
}
57+
}
58+
}
59+
60+
await walk(dir);
61+
return files;
62+
}
63+
64+
type Pattern =
65+
| 'left.op(right)' // e.g. vec + 2 => vec.add(2)
66+
| 'right.op(left)' // e.g. 2 * vec => vec.mul(2)
67+
| 'std.op(left, right)'; // e.g. 2 / vec => std.div(2, vec)
68+
69+
function getOverloadPattern(
70+
checker: ts.TypeChecker,
71+
node: ts.BinaryExpression,
72+
): Pattern | undefined {
73+
const methodName = operatorToMethod[node.operatorToken.kind];
74+
if (!methodName) {
75+
// Not overlaoded
76+
return undefined;
77+
}
78+
79+
// Get the types of both operands
80+
const leftType = checker.getTypeAtLocation(node.left);
81+
const rightType = checker.getTypeAtLocation(node.right);
82+
83+
if (
84+
!checker.__tsover__couldHaveOverloadedOperators(
85+
node.left,
86+
node.operatorToken.kind,
87+
node.right,
88+
leftType,
89+
rightType,
90+
)
91+
) {
92+
// Not overlaoded
93+
return undefined;
94+
}
95+
96+
// For non-commutative operators, use the standard library function
97+
if (!commutativeMethods.includes(methodName)) {
98+
return 'std.op(left, right)';
99+
}
100+
101+
// Since other supported operators are commutative, prefer left method, fall back to right
102+
const leftHasMethod = leftType.getProperty(methodName) !== undefined;
103+
104+
return leftHasMethod ? 'left.op(right)' : 'right.op(left)';
105+
}
106+
107+
function createProgram(allFiles: string[]): ts.Program {
108+
const configPath = join(projectRoot, 'tsconfig.json');
109+
const configText = ts.sys.readFile(configPath);
110+
111+
if (!configText) {
112+
throw new Error(`Could not read tsconfig.json at ${configPath}`);
113+
}
114+
115+
const { config } = ts.parseConfigFileTextToJson(configPath, configText);
116+
const parsedConfig = ts.parseJsonConfigFileContent(
117+
config,
118+
ts.sys,
119+
projectRoot,
120+
);
121+
122+
const compilerOptions: ts.CompilerOptions = {
123+
...parsedConfig.options,
124+
noEmit: true,
125+
};
126+
127+
const host = ts.createCompilerHost(compilerOptions, true);
128+
129+
return ts.createProgram(allFiles, compilerOptions, host);
130+
}
131+
132+
function isStdDeclared(sourceFile: ts.SourceFile): boolean {
133+
for (const stmt of sourceFile.statements) {
134+
if (!ts.isImportDeclaration(stmt)) {
135+
continue;
136+
}
137+
const moduleSpecifier = stmt.moduleSpecifier;
138+
if (!ts.isStringLiteral(moduleSpecifier)) {
139+
continue;
140+
}
141+
const namedBindings = stmt.importClause?.namedBindings;
142+
// Case 1: import { std } from 'typegpu'
143+
if (
144+
moduleSpecifier.text === 'typegpu' &&
145+
namedBindings &&
146+
ts.isNamedImports(namedBindings) &&
147+
namedBindings.elements.some((el) => el.name.text === 'std')
148+
) {
149+
return true;
150+
}
151+
// Case 2: import * as std from 'typegpu/std'
152+
if (
153+
moduleSpecifier.text === 'typegpu/std' &&
154+
namedBindings &&
155+
ts.isNamespaceImport(namedBindings) &&
156+
namedBindings.name.text === 'std'
157+
) {
158+
return true;
159+
}
160+
}
161+
return false;
162+
}
163+
164+
function transformFile(
165+
sourceFilePath: string,
166+
program: ts.Program,
167+
): { code: string; hasChanges: boolean } {
168+
const checker = program.getTypeChecker();
169+
const sourceFile = program.getSourceFile(sourceFilePath);
170+
171+
if (!sourceFile) {
172+
throw new Error(`Could not get source file for ${sourceFilePath}`);
173+
}
174+
175+
const sourceText = sourceFile.text;
176+
const magic = new MagicString(sourceText);
177+
let hasChanges = false;
178+
let needsStd = false;
179+
180+
function visit(node: ts.Node): void {
181+
// Visit all children of the node first, then process the node itself
182+
ts.forEachChild(node, visit);
183+
184+
if (!ts.isBinaryExpression(node)) {
185+
return;
186+
}
187+
188+
const pattern = getOverloadPattern(checker, node);
189+
const methodName = operatorToMethod[node.operatorToken.kind];
190+
191+
if (!pattern || !methodName) {
192+
return;
193+
}
194+
195+
hasChanges = true;
196+
197+
const start = node.getStart();
198+
const end = node.getEnd();
199+
const leftStart = node.left.getStart();
200+
const leftEnd = node.left.getEnd();
201+
const rightStart = node.right.getStart();
202+
const rightEnd = node.right.getEnd();
203+
204+
const leftText = magic.slice(leftStart, leftEnd);
205+
const rightText = magic.slice(rightStart, rightEnd);
206+
207+
let replacement = '';
208+
if (pattern === 'std.op(left, right)') {
209+
needsStd = true;
210+
replacement = `std.${methodName}(${leftText}, ${rightText})`;
211+
} else if (pattern === 'left.op(right)') {
212+
replacement = `${leftText}.${methodName}(${rightText})`;
213+
} else if (pattern === 'right.op(left)') {
214+
replacement = `${rightText}.${methodName}(${leftText})`;
215+
} else {
216+
throw new Error(`Unsupported pattern: ${pattern}`);
217+
}
218+
219+
if (assignmentOperators.includes(node.operatorToken.kind)) {
220+
// E.g. transforms a += b into a = a.add(b)
221+
magic.overwrite(start, end, `${leftText} = ${replacement}`);
222+
} else {
223+
magic.overwrite(start, end, replacement);
224+
}
225+
}
226+
227+
visit(sourceFile);
228+
229+
if (needsStd && !isStdDeclared(sourceFile)) {
230+
magic.prepend("import { std } from 'typegpu';\n");
231+
}
232+
233+
return { code: magic.toString(), hasChanges };
234+
}
235+
236+
async function main() {
237+
console.log('Finding TypeScript files in examples directory...');
238+
239+
const allFiles = await findTypeScriptFiles(examplesDir);
240+
241+
console.log(`Found ${allFiles.length} files to process`);
242+
console.log('Creating TypeScript program...');
243+
244+
const program = createProgram(allFiles);
245+
246+
console.log('Transforming files...');
247+
248+
let transformedCount = 0;
249+
let errorCount = 0;
250+
251+
for (const filePath of allFiles) {
252+
try {
253+
const { code, hasChanges } = transformFile(filePath, program);
254+
255+
const ext = filePath.endsWith('.tsx') ? '.tsx' : '.ts';
256+
const baseName = basename(filePath, ext);
257+
const dir = dirname(filePath);
258+
const outputPath = join(dir, `${baseName}.tsnotover${ext}`);
259+
260+
if (hasChanges) {
261+
transformedCount++;
262+
await writeFile(outputPath, code, 'utf-8');
263+
console.log(`Transformed: ${relative(projectRoot, filePath)}`);
264+
}
265+
} catch (error) {
266+
errorCount++;
267+
const errorMessage = error instanceof Error
268+
? error.message
269+
: String(error);
270+
console.error(
271+
`Error processing ${relative(projectRoot, filePath)}: ${errorMessage}`,
272+
);
273+
throw new Error(
274+
`Failed to transform ${
275+
relative(projectRoot, filePath)
276+
}: ${errorMessage}`,
277+
{ cause: error },
278+
);
279+
}
280+
}
281+
282+
console.log(
283+
`\nDone! Transformed ${transformedCount} files, ${errorCount} errors.`,
284+
);
285+
}
286+
287+
main().catch((error) => {
288+
console.error('Fatal error:', error);
289+
process.exit(1);
290+
});

apps/typegpu-docs/src/components/CodeEditor.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ const createCodeEditorComponent = (
8484
>
8585
<Editor
8686
defaultLanguage={language}
87-
value={file.content}
87+
value={file.tsnotoverContent ?? file.content}
8888
path={path}
8989
beforeMount={beforeMount}
9090
onMount={onMount}

apps/typegpu-docs/src/components/stackblitz/openInStackBlitz.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ export const openInStackBlitz = (
5757
const tsFiles: Record<string, string> = {};
5858

5959
for (const file of example.tsFiles) {
60-
tsFiles[`src/${file.path}`] = file.content;
60+
tsFiles[`src/${file.path}`] = file.tsnotoverContent ?? file.content;
6161
}
6262
for (const file of common) {
63-
tsFiles[`src/common/${file.path}`] = file.content;
63+
tsFiles[`src/common/${file.path}`] = file.tsnotoverContent ?? file.content;
6464
}
6565

6666
for (const key of Object.keys(tsFiles)) {

0 commit comments

Comments
 (0)