diff --git a/pom.xml b/pom.xml
index b7534dd..c2dc424 100644
--- a/pom.xml
+++ b/pom.xml
@@ -14,7 +14,7 @@
UTF-8
- 21
+ 21
21
@@ -60,7 +60,18 @@
-XDcompilePolicy=simple
--should-stop=ifError=FLOW
-Xplugin:ErrorProne -Xep:NullAway:ERROR -XepOpt:NullAway:AnnotatedPackages=com.garciat.typeclasses
+ --add-exports=jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED
+ --add-exports=jdk.compiler/com.sun.tools.javac.code=ALL-UNNAMED
+ --add-exports=jdk.compiler/com.sun.tools.javac.comp=ALL-UNNAMED
+ --add-exports=jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED
+ --add-exports=jdk.compiler/com.sun.tools.javac.main=ALL-UNNAMED
+ --add-exports=jdk.compiler/com.sun.tools.javac.model=ALL-UNNAMED
+ --add-exports=jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED
+ --add-exports=jdk.compiler/com.sun.tools.javac.processing=ALL-UNNAMED
+ --add-exports=jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED
+ --add-exports=jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED
-J--add-exports=jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED
+ -J--add-exports=jdk.compiler/com.sun.tools.javac.comp=ALL-UNNAMED
-J--add-exports=jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED
-J--add-exports=jdk.compiler/com.sun.tools.javac.main=ALL-UNNAMED
-J--add-exports=jdk.compiler/com.sun.tools.javac.model=ALL-UNNAMED
@@ -88,6 +99,17 @@
maven-surefire-plugin
3.5.4
+
+
+ --add-exports=jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED
+ --add-exports=jdk.compiler/com.sun.tools.javac.code=ALL-UNNAMED
+ --add-exports=jdk.compiler/com.sun.tools.javac.comp=ALL-UNNAMED
+ --add-exports=jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED
+ --add-exports=jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED
+ --add-opens=jdk.compiler/com.sun.tools.javac.code=ALL-UNNAMED
+ --add-opens=jdk.compiler/com.sun.tools.javac.comp=ALL-UNNAMED
+
+
org.jacoco
diff --git a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java
index a367f43..6db9c02 100644
--- a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java
+++ b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java
@@ -16,8 +16,19 @@
import com.sun.source.util.TaskEvent;
import com.sun.source.util.TaskListener;
import com.sun.source.util.TreePath;
-import com.sun.source.util.TreePathScanner;
import com.sun.source.util.Trees;
+import com.sun.tools.javac.code.Symbol;
+import com.sun.tools.javac.code.Symtab;
+import com.sun.tools.javac.code.Type;
+import com.sun.tools.javac.code.Types;
+import com.sun.tools.javac.comp.Attr;
+import com.sun.tools.javac.comp.AttrContext;
+import com.sun.tools.javac.comp.Env;
+import com.sun.tools.javac.comp.Resolve;
+import com.sun.tools.javac.tree.JCTree;
+import com.sun.tools.javac.tree.TreeMaker;
+import com.sun.tools.javac.util.Context;
+import com.sun.tools.javac.util.Names;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Objects;
@@ -48,65 +59,266 @@ public String getName() {
@Override
public void init(JavacTask task, String... args) {
+ Context context = ((com.sun.tools.javac.api.BasicJavacTask) task).getContext();
+
task.addTaskListener(
new TaskListener() {
@Override
public void finished(TaskEvent e) {
- if (e.getKind() != TaskEvent.Kind.ANALYZE) {
- return;
- }
-
- if (e.getCompilationUnit() == null) {
- return;
+ if (e.getKind() == TaskEvent.Kind.ANALYZE) {
+ // Transform immediately after ANALYZE when type info is available
+ if (e.getCompilationUnit() != null) {
+ new WitnessCallRewriter(Trees.instance(task), context)
+ .transform((JCTree.JCCompilationUnit) e.getCompilationUnit());
+ }
}
-
- new WitnessCallScanner(Trees.instance(task)).scan(e.getCompilationUnit(), null);
}
});
}
- /** Scanner that finds calls to TypeClasses.witness() and validates them. */
- private static class WitnessCallScanner extends TreePathScanner {
+ /** Rewriter that finds calls to TypeClasses.witness() and rewrites them with actual calls. */
+ @SuppressWarnings("NullAway")
+ private static class WitnessCallRewriter extends com.sun.tools.javac.tree.TreeTranslator {
private final Trees trees;
private final StaticWitnessSystem system;
+ private final TreeMaker treeMaker;
+ private final Names names;
+ private final Attr attr;
+ private final Resolve resolve;
+ private final Types types;
+ private final Symtab symtab;
+ private JCTree.JCCompilationUnit currentCompilationUnit;
+ private Env env;
- private WitnessCallScanner(Trees trees) {
+ private WitnessCallRewriter(Trees trees, Context context) {
this.trees = trees;
this.system = new StaticWitnessSystem();
+ this.treeMaker = TreeMaker.instance(context);
+ this.names = Names.instance(context);
+ this.attr = Attr.instance(context);
+ this.resolve = Resolve.instance(context);
+ this.types = Types.instance(context);
+ this.symtab = Symtab.instance(context);
+ }
+
+ /** Transform the compilation unit by replacing witness() calls. */
+ public void transform(JCTree.JCCompilationUnit compilationUnit) {
+ this.currentCompilationUnit = compilationUnit;
+ // First scan to validate witness calls
+ new ValidationScanner().scan(compilationUnit, null);
+ // Then apply transformations
+ compilationUnit.defs = translate(compilationUnit.defs);
+ }
+
+ /** Scanner to validate witness calls before transformation. */
+ private class ValidationScanner extends com.sun.source.util.TreePathScanner {
+ @Override
+ public Void visitMethodInvocation(MethodInvocationTree node, Void arg) {
+ Parser.identity()
+ .guard(
+ Parser.currentElement()
+ .flatMap(Parser.methodMatches(WITNESS_METHOD)))
+ .flatMap(Parser.unaryCallArgument())
+ .flatMap(Parser.newAnonymousClassBody())
+ .flatMap(Parser.singleImplementsClause())
+ .flatMap(Parser.treeTypeMirror())
+ .flatMap(Parser.rawTypeMatches(Ty.class))
+ .flatMap(Parser.unaryTypeArgument())
+ .parse(trees, getCurrentPath(), node)
+ .fold(
+ Unit::unit,
+ witnessType ->
+ WitnessResolution.resolve(system, system.parse(witnessType))
+ .fold(
+ error -> {
+ trees.printMessage(
+ Diagnostic.Kind.ERROR,
+ "Failed to resolve witness for type: "
+ + witnessType
+ + "\nReason: "
+ + error.format(),
+ getCurrentPath().getLeaf(),
+ getCurrentPath().getCompilationUnit());
+ return unit();
+ },
+ plan -> unit()));
+
+ return super.visitMethodInvocation(node, arg);
+ }
}
@Override
- public Void visitMethodInvocation(MethodInvocationTree node, Void arg) {
- Parser.identity()
- .guard(
- Parser.currentElement()
- .flatMap(Parser.methodMatches(WITNESS_METHOD)))
- .flatMap(Parser.unaryCallArgument())
- .flatMap(Parser.newAnonymousClassBody())
- .flatMap(Parser.singleImplementsClause())
- .flatMap(Parser.treeTypeMirror())
- .flatMap(Parser.rawTypeMatches(Ty.class))
- .flatMap(Parser.unaryTypeArgument())
- .parse(trees, getCurrentPath(), node)
- .fold(
- Unit::unit,
- witnessType ->
- WitnessResolution.resolve(system, system.parse(witnessType))
- .fold(
- error -> {
- this.trees.printMessage(
- Diagnostic.Kind.ERROR,
- "Failed to resolve witness for type: "
- + witnessType
- + "\nReason: "
- + error.format(),
- getCurrentPath().getLeaf(),
- getCurrentPath().getCompilationUnit());
- return unit();
- },
- plan -> unit()));
-
- return super.visitMethodInvocation(node, arg);
+ public void visitApply(JCTree.JCMethodInvocation tree) {
+ // Try to transform witness() calls with proper attribution
+ TreePath path = trees.getPath(currentCompilationUnit, tree);
+ if (path != null) {
+ Maybe witnessType =
+ Parser.identity()
+ .guard(
+ Parser.currentElement()
+ .flatMap(Parser.methodMatches(WITNESS_METHOD)))
+ .flatMap(Parser.unaryCallArgument())
+ .flatMap(Parser.newAnonymousClassBody())
+ .flatMap(Parser.singleImplementsClause())
+ .flatMap(Parser.treeTypeMirror())
+ .flatMap(Parser.rawTypeMatches(Ty.class))
+ .flatMap(Parser.unaryTypeArgument())
+ .parse(trees, path, tree);
+
+ witnessType.fold(
+ Unit::unit,
+ wt -> {
+ WitnessResolution.resolve(system, system.parse(wt))
+ .fold(
+ error -> unit(), // Error already reported in validation
+ plan -> {
+ try {
+ // Build the replacement tree
+ JCTree.JCExpression replacement = buildInstantiationTree(plan);
+ if (replacement != null) {
+ replacement.pos = tree.pos;
+
+ // Try to resolve and attribute the replacement
+ // This is the key: we need to attribute the new tree
+ if (tree.type != null) {
+ // Attempt to set the type on the replacement
+ replacement.type = tree.type;
+ }
+
+ result = replacement;
+ }
+ } catch (Exception e) {
+ // If transformation fails, keep original
+ result = tree;
+ }
+ return unit();
+ });
+ return unit();
+ });
+ }
+
+ // Only call super if we didn't transform
+ if (result == null || result == tree) {
+ super.visitApply(tree);
+ }
+ }
+
+ /** Recursively builds a JCTree from an InstantiationPlan with proper type attribution. */
+ private JCTree.JCExpression buildInstantiationTree(WitnessResolution.InstantiationPlan plan) {
+ return switch (plan) {
+ case WitnessResolution.InstantiationPlan.PlanStep(var constructor, var dependencies) -> {
+ // Get the ExecutableElement for the witness constructor
+ ExecutableElement method = constructor.method();
+
+ // Build the method reference
+ JCTree.JCExpression methodSelect = buildMethodReference(method);
+
+ // Build the arguments by recursively processing dependencies
+ com.sun.tools.javac.util.List args =
+ com.sun.tools.javac.util.List.from(
+ dependencies.stream().map(this::buildInstantiationTree).toList());
+
+ // Create the method invocation
+ JCTree.JCMethodInvocation methodInvocation =
+ treeMaker.Apply(com.sun.tools.javac.util.List.nil(), methodSelect, args);
+
+ // Set the type on the method invocation
+ // The return type of the witness constructor method
+ if (method instanceof Symbol.MethodSymbol methodSymbol) {
+ methodInvocation.type = methodSymbol.getReturnType();
+ } else if (method.getReturnType() instanceof javax.lang.model.type.DeclaredType dt) {
+ // Try to get the type from the ExecutableElement
+ TypeMirror returnType = method.getReturnType();
+ if (returnType instanceof Type jcType) {
+ methodInvocation.type = jcType;
+ }
+ }
+
+ yield methodInvocation;
+ }
+ };
+ }
+
+ /**
+ * Builds a JCTree expression that references the given method with proper symbol information.
+ */
+ private JCTree.JCExpression buildMethodReference(ExecutableElement method) {
+ // Get the enclosing class
+ Element enclosingElement = method.getEnclosingElement();
+
+ if (enclosingElement instanceof TypeElement typeElement) {
+ // Build the class name expression
+ JCTree.JCExpression classExpr = buildQualifiedName(typeElement);
+
+ // Build the method name
+ com.sun.tools.javac.util.Name methodName =
+ names.fromString(method.getSimpleName().toString());
+
+ // Create the field access (ClassName.methodName)
+ JCTree.JCFieldAccess fieldAccess = treeMaker.Select(classExpr, methodName);
+
+ // Set the symbol if method is a Symbol.MethodSymbol
+ if (method instanceof Symbol.MethodSymbol methodSymbol) {
+ fieldAccess.sym = methodSymbol;
+ fieldAccess.type = methodSymbol.type;
+ }
+
+ return fieldAccess;
+ } else {
+ throw new IllegalArgumentException(
+ "Method does not have a TypeElement as enclosing element: " + method);
+ }
+ }
+
+ /** Builds a qualified name expression with proper symbol and type information. */
+ private JCTree.JCExpression buildQualifiedName(TypeElement typeElement) {
+ String qualifiedName = typeElement.getQualifiedName().toString();
+ String[] parts = qualifiedName.split("\\.");
+
+ JCTree.JCExpression expr = treeMaker.Ident(names.fromString(parts[0]));
+
+ // Set symbol and type if typeElement is a Symbol.ClassSymbol
+ if (typeElement instanceof Symbol.ClassSymbol classSymbol) {
+ // For the final expression, set the class symbol
+ Symbol currentSym = classSymbol;
+
+ // Navigate through package symbols to build the path
+ for (int i = parts.length - 2; i >= 0; i--) {
+ if (currentSym.owner != null) {
+ currentSym = currentSym.owner;
+ }
+ }
+
+ // Build the expression with proper types
+ if (parts.length == 1) {
+ if (expr instanceof JCTree.JCIdent ident) {
+ ident.sym = classSymbol;
+ ident.type = classSymbol.type;
+ }
+ } else {
+ Symbol pkgSym = currentSym;
+ if (expr instanceof JCTree.JCIdent ident && pkgSym instanceof Symbol.PackageSymbol) {
+ ident.sym = pkgSym;
+ ident.type = pkgSym.type;
+ }
+
+ for (int i = 1; i < parts.length; i++) {
+ JCTree.JCFieldAccess select = treeMaker.Select(expr, names.fromString(parts[i]));
+ if (i == parts.length - 1) {
+ select.sym = classSymbol;
+ select.type = classSymbol.type;
+ }
+ expr = select;
+ }
+ }
+ } else {
+ // Fallback: just build the structure without symbols
+ for (int i = 1; i < parts.length; i++) {
+ expr = treeMaker.Select(expr, names.fromString(parts[i]));
+ }
+ }
+
+ return expr;
}
}
}