diff --git a/pom.xml b/pom.xml index b7534dd..c2dc424 100644 --- a/pom.xml +++ b/pom.xml @@ -14,7 +14,7 @@ UTF-8 - 21 + 21 21 @@ -60,7 +60,18 @@ -XDcompilePolicy=simple --should-stop=ifError=FLOW -Xplugin:ErrorProne -Xep:NullAway:ERROR -XepOpt:NullAway:AnnotatedPackages=com.garciat.typeclasses + --add-exports=jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.code=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.comp=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.main=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.model=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.processing=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED -J--add-exports=jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED + -J--add-exports=jdk.compiler/com.sun.tools.javac.comp=ALL-UNNAMED -J--add-exports=jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED -J--add-exports=jdk.compiler/com.sun.tools.javac.main=ALL-UNNAMED -J--add-exports=jdk.compiler/com.sun.tools.javac.model=ALL-UNNAMED @@ -88,6 +99,17 @@ maven-surefire-plugin 3.5.4 + + + --add-exports=jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.code=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.comp=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED + --add-opens=jdk.compiler/com.sun.tools.javac.code=ALL-UNNAMED + --add-opens=jdk.compiler/com.sun.tools.javac.comp=ALL-UNNAMED + + org.jacoco diff --git a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java index a367f43..6db9c02 100644 --- a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java +++ b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java @@ -16,8 +16,19 @@ import com.sun.source.util.TaskEvent; import com.sun.source.util.TaskListener; import com.sun.source.util.TreePath; -import com.sun.source.util.TreePathScanner; import com.sun.source.util.Trees; +import com.sun.tools.javac.code.Symbol; +import com.sun.tools.javac.code.Symtab; +import com.sun.tools.javac.code.Type; +import com.sun.tools.javac.code.Types; +import com.sun.tools.javac.comp.Attr; +import com.sun.tools.javac.comp.AttrContext; +import com.sun.tools.javac.comp.Env; +import com.sun.tools.javac.comp.Resolve; +import com.sun.tools.javac.tree.JCTree; +import com.sun.tools.javac.tree.TreeMaker; +import com.sun.tools.javac.util.Context; +import com.sun.tools.javac.util.Names; import java.lang.reflect.Method; import java.util.List; import java.util.Objects; @@ -48,65 +59,266 @@ public String getName() { @Override public void init(JavacTask task, String... args) { + Context context = ((com.sun.tools.javac.api.BasicJavacTask) task).getContext(); + task.addTaskListener( new TaskListener() { @Override public void finished(TaskEvent e) { - if (e.getKind() != TaskEvent.Kind.ANALYZE) { - return; - } - - if (e.getCompilationUnit() == null) { - return; + if (e.getKind() == TaskEvent.Kind.ANALYZE) { + // Transform immediately after ANALYZE when type info is available + if (e.getCompilationUnit() != null) { + new WitnessCallRewriter(Trees.instance(task), context) + .transform((JCTree.JCCompilationUnit) e.getCompilationUnit()); + } } - - new WitnessCallScanner(Trees.instance(task)).scan(e.getCompilationUnit(), null); } }); } - /** Scanner that finds calls to TypeClasses.witness() and validates them. */ - private static class WitnessCallScanner extends TreePathScanner { + /** Rewriter that finds calls to TypeClasses.witness() and rewrites them with actual calls. */ + @SuppressWarnings("NullAway") + private static class WitnessCallRewriter extends com.sun.tools.javac.tree.TreeTranslator { private final Trees trees; private final StaticWitnessSystem system; + private final TreeMaker treeMaker; + private final Names names; + private final Attr attr; + private final Resolve resolve; + private final Types types; + private final Symtab symtab; + private JCTree.JCCompilationUnit currentCompilationUnit; + private Env env; - private WitnessCallScanner(Trees trees) { + private WitnessCallRewriter(Trees trees, Context context) { this.trees = trees; this.system = new StaticWitnessSystem(); + this.treeMaker = TreeMaker.instance(context); + this.names = Names.instance(context); + this.attr = Attr.instance(context); + this.resolve = Resolve.instance(context); + this.types = Types.instance(context); + this.symtab = Symtab.instance(context); + } + + /** Transform the compilation unit by replacing witness() calls. */ + public void transform(JCTree.JCCompilationUnit compilationUnit) { + this.currentCompilationUnit = compilationUnit; + // First scan to validate witness calls + new ValidationScanner().scan(compilationUnit, null); + // Then apply transformations + compilationUnit.defs = translate(compilationUnit.defs); + } + + /** Scanner to validate witness calls before transformation. */ + private class ValidationScanner extends com.sun.source.util.TreePathScanner { + @Override + public Void visitMethodInvocation(MethodInvocationTree node, Void arg) { + Parser.identity() + .guard( + Parser.currentElement() + .flatMap(Parser.methodMatches(WITNESS_METHOD))) + .flatMap(Parser.unaryCallArgument()) + .flatMap(Parser.newAnonymousClassBody()) + .flatMap(Parser.singleImplementsClause()) + .flatMap(Parser.treeTypeMirror()) + .flatMap(Parser.rawTypeMatches(Ty.class)) + .flatMap(Parser.unaryTypeArgument()) + .parse(trees, getCurrentPath(), node) + .fold( + Unit::unit, + witnessType -> + WitnessResolution.resolve(system, system.parse(witnessType)) + .fold( + error -> { + trees.printMessage( + Diagnostic.Kind.ERROR, + "Failed to resolve witness for type: " + + witnessType + + "\nReason: " + + error.format(), + getCurrentPath().getLeaf(), + getCurrentPath().getCompilationUnit()); + return unit(); + }, + plan -> unit())); + + return super.visitMethodInvocation(node, arg); + } } @Override - public Void visitMethodInvocation(MethodInvocationTree node, Void arg) { - Parser.identity() - .guard( - Parser.currentElement() - .flatMap(Parser.methodMatches(WITNESS_METHOD))) - .flatMap(Parser.unaryCallArgument()) - .flatMap(Parser.newAnonymousClassBody()) - .flatMap(Parser.singleImplementsClause()) - .flatMap(Parser.treeTypeMirror()) - .flatMap(Parser.rawTypeMatches(Ty.class)) - .flatMap(Parser.unaryTypeArgument()) - .parse(trees, getCurrentPath(), node) - .fold( - Unit::unit, - witnessType -> - WitnessResolution.resolve(system, system.parse(witnessType)) - .fold( - error -> { - this.trees.printMessage( - Diagnostic.Kind.ERROR, - "Failed to resolve witness for type: " - + witnessType - + "\nReason: " - + error.format(), - getCurrentPath().getLeaf(), - getCurrentPath().getCompilationUnit()); - return unit(); - }, - plan -> unit())); - - return super.visitMethodInvocation(node, arg); + public void visitApply(JCTree.JCMethodInvocation tree) { + // Try to transform witness() calls with proper attribution + TreePath path = trees.getPath(currentCompilationUnit, tree); + if (path != null) { + Maybe witnessType = + Parser.identity() + .guard( + Parser.currentElement() + .flatMap(Parser.methodMatches(WITNESS_METHOD))) + .flatMap(Parser.unaryCallArgument()) + .flatMap(Parser.newAnonymousClassBody()) + .flatMap(Parser.singleImplementsClause()) + .flatMap(Parser.treeTypeMirror()) + .flatMap(Parser.rawTypeMatches(Ty.class)) + .flatMap(Parser.unaryTypeArgument()) + .parse(trees, path, tree); + + witnessType.fold( + Unit::unit, + wt -> { + WitnessResolution.resolve(system, system.parse(wt)) + .fold( + error -> unit(), // Error already reported in validation + plan -> { + try { + // Build the replacement tree + JCTree.JCExpression replacement = buildInstantiationTree(plan); + if (replacement != null) { + replacement.pos = tree.pos; + + // Try to resolve and attribute the replacement + // This is the key: we need to attribute the new tree + if (tree.type != null) { + // Attempt to set the type on the replacement + replacement.type = tree.type; + } + + result = replacement; + } + } catch (Exception e) { + // If transformation fails, keep original + result = tree; + } + return unit(); + }); + return unit(); + }); + } + + // Only call super if we didn't transform + if (result == null || result == tree) { + super.visitApply(tree); + } + } + + /** Recursively builds a JCTree from an InstantiationPlan with proper type attribution. */ + private JCTree.JCExpression buildInstantiationTree(WitnessResolution.InstantiationPlan plan) { + return switch (plan) { + case WitnessResolution.InstantiationPlan.PlanStep(var constructor, var dependencies) -> { + // Get the ExecutableElement for the witness constructor + ExecutableElement method = constructor.method(); + + // Build the method reference + JCTree.JCExpression methodSelect = buildMethodReference(method); + + // Build the arguments by recursively processing dependencies + com.sun.tools.javac.util.List args = + com.sun.tools.javac.util.List.from( + dependencies.stream().map(this::buildInstantiationTree).toList()); + + // Create the method invocation + JCTree.JCMethodInvocation methodInvocation = + treeMaker.Apply(com.sun.tools.javac.util.List.nil(), methodSelect, args); + + // Set the type on the method invocation + // The return type of the witness constructor method + if (method instanceof Symbol.MethodSymbol methodSymbol) { + methodInvocation.type = methodSymbol.getReturnType(); + } else if (method.getReturnType() instanceof javax.lang.model.type.DeclaredType dt) { + // Try to get the type from the ExecutableElement + TypeMirror returnType = method.getReturnType(); + if (returnType instanceof Type jcType) { + methodInvocation.type = jcType; + } + } + + yield methodInvocation; + } + }; + } + + /** + * Builds a JCTree expression that references the given method with proper symbol information. + */ + private JCTree.JCExpression buildMethodReference(ExecutableElement method) { + // Get the enclosing class + Element enclosingElement = method.getEnclosingElement(); + + if (enclosingElement instanceof TypeElement typeElement) { + // Build the class name expression + JCTree.JCExpression classExpr = buildQualifiedName(typeElement); + + // Build the method name + com.sun.tools.javac.util.Name methodName = + names.fromString(method.getSimpleName().toString()); + + // Create the field access (ClassName.methodName) + JCTree.JCFieldAccess fieldAccess = treeMaker.Select(classExpr, methodName); + + // Set the symbol if method is a Symbol.MethodSymbol + if (method instanceof Symbol.MethodSymbol methodSymbol) { + fieldAccess.sym = methodSymbol; + fieldAccess.type = methodSymbol.type; + } + + return fieldAccess; + } else { + throw new IllegalArgumentException( + "Method does not have a TypeElement as enclosing element: " + method); + } + } + + /** Builds a qualified name expression with proper symbol and type information. */ + private JCTree.JCExpression buildQualifiedName(TypeElement typeElement) { + String qualifiedName = typeElement.getQualifiedName().toString(); + String[] parts = qualifiedName.split("\\."); + + JCTree.JCExpression expr = treeMaker.Ident(names.fromString(parts[0])); + + // Set symbol and type if typeElement is a Symbol.ClassSymbol + if (typeElement instanceof Symbol.ClassSymbol classSymbol) { + // For the final expression, set the class symbol + Symbol currentSym = classSymbol; + + // Navigate through package symbols to build the path + for (int i = parts.length - 2; i >= 0; i--) { + if (currentSym.owner != null) { + currentSym = currentSym.owner; + } + } + + // Build the expression with proper types + if (parts.length == 1) { + if (expr instanceof JCTree.JCIdent ident) { + ident.sym = classSymbol; + ident.type = classSymbol.type; + } + } else { + Symbol pkgSym = currentSym; + if (expr instanceof JCTree.JCIdent ident && pkgSym instanceof Symbol.PackageSymbol) { + ident.sym = pkgSym; + ident.type = pkgSym.type; + } + + for (int i = 1; i < parts.length; i++) { + JCTree.JCFieldAccess select = treeMaker.Select(expr, names.fromString(parts[i])); + if (i == parts.length - 1) { + select.sym = classSymbol; + select.type = classSymbol.type; + } + expr = select; + } + } + } else { + // Fallback: just build the structure without symbols + for (int i = 1; i < parts.length; i++) { + expr = treeMaker.Select(expr, names.fromString(parts[i])); + } + } + + return expr; } } }