Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.release>21</maven.compiler.release>
<maven.compiler.source>21</maven.compiler.source>
<maven.compiler.target>21</maven.compiler.target>
</properties>

Expand Down Expand Up @@ -60,7 +60,18 @@
<arg>-XDcompilePolicy=simple</arg>
<arg>--should-stop=ifError=FLOW</arg>
<arg>-Xplugin:ErrorProne -Xep:NullAway:ERROR -XepOpt:NullAway:AnnotatedPackages=com.garciat.typeclasses</arg>
<arg>--add-exports=jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED</arg>
<arg>--add-exports=jdk.compiler/com.sun.tools.javac.code=ALL-UNNAMED</arg>
<arg>--add-exports=jdk.compiler/com.sun.tools.javac.comp=ALL-UNNAMED</arg>
<arg>--add-exports=jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED</arg>
<arg>--add-exports=jdk.compiler/com.sun.tools.javac.main=ALL-UNNAMED</arg>
<arg>--add-exports=jdk.compiler/com.sun.tools.javac.model=ALL-UNNAMED</arg>
<arg>--add-exports=jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED</arg>
<arg>--add-exports=jdk.compiler/com.sun.tools.javac.processing=ALL-UNNAMED</arg>
<arg>--add-exports=jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED</arg>
<arg>--add-exports=jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED</arg>
<arg>-J--add-exports=jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED</arg>
<arg>-J--add-exports=jdk.compiler/com.sun.tools.javac.comp=ALL-UNNAMED</arg>
<arg>-J--add-exports=jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED</arg>
<arg>-J--add-exports=jdk.compiler/com.sun.tools.javac.main=ALL-UNNAMED</arg>
<arg>-J--add-exports=jdk.compiler/com.sun.tools.javac.model=ALL-UNNAMED</arg>
Expand Down Expand Up @@ -88,6 +99,17 @@
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.5.4</version>
<configuration>
<argLine>
--add-exports=jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED
--add-exports=jdk.compiler/com.sun.tools.javac.code=ALL-UNNAMED
--add-exports=jdk.compiler/com.sun.tools.javac.comp=ALL-UNNAMED
--add-exports=jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED
--add-exports=jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED
--add-opens=jdk.compiler/com.sun.tools.javac.code=ALL-UNNAMED
--add-opens=jdk.compiler/com.sun.tools.javac.comp=ALL-UNNAMED
</argLine>
</configuration>
</plugin>
<plugin>
<groupId>org.jacoco</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,19 @@
import com.sun.source.util.TaskEvent;
import com.sun.source.util.TaskListener;
import com.sun.source.util.TreePath;
import com.sun.source.util.TreePathScanner;
import com.sun.source.util.Trees;
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Symtab;
import com.sun.tools.javac.code.Type;
import com.sun.tools.javac.code.Types;
import com.sun.tools.javac.comp.Attr;
import com.sun.tools.javac.comp.AttrContext;
import com.sun.tools.javac.comp.Env;
import com.sun.tools.javac.comp.Resolve;
import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.TreeMaker;
import com.sun.tools.javac.util.Context;
import com.sun.tools.javac.util.Names;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Objects;
Expand Down Expand Up @@ -48,65 +59,266 @@ public String getName() {

@Override
public void init(JavacTask task, String... args) {
Context context = ((com.sun.tools.javac.api.BasicJavacTask) task).getContext();

task.addTaskListener(
new TaskListener() {
@Override
public void finished(TaskEvent e) {
if (e.getKind() != TaskEvent.Kind.ANALYZE) {
return;
}

if (e.getCompilationUnit() == null) {
return;
if (e.getKind() == TaskEvent.Kind.ANALYZE) {
// Transform immediately after ANALYZE when type info is available
if (e.getCompilationUnit() != null) {
new WitnessCallRewriter(Trees.instance(task), context)
.transform((JCTree.JCCompilationUnit) e.getCompilationUnit());
}
}

new WitnessCallScanner(Trees.instance(task)).scan(e.getCompilationUnit(), null);
}
});
}

/** Scanner that finds calls to TypeClasses.witness() and validates them. */
private static class WitnessCallScanner extends TreePathScanner<Void, Void> {
/** Rewriter that finds calls to TypeClasses.witness() and rewrites them with actual calls. */
@SuppressWarnings("NullAway")
private static class WitnessCallRewriter extends com.sun.tools.javac.tree.TreeTranslator {
private final Trees trees;
private final StaticWitnessSystem system;
private final TreeMaker treeMaker;
private final Names names;
private final Attr attr;
private final Resolve resolve;
private final Types types;
private final Symtab symtab;
private JCTree.JCCompilationUnit currentCompilationUnit;
private Env<AttrContext> env;

private WitnessCallScanner(Trees trees) {
private WitnessCallRewriter(Trees trees, Context context) {
this.trees = trees;
this.system = new StaticWitnessSystem();
this.treeMaker = TreeMaker.instance(context);
this.names = Names.instance(context);
this.attr = Attr.instance(context);
this.resolve = Resolve.instance(context);
this.types = Types.instance(context);
this.symtab = Symtab.instance(context);
}

/** Transform the compilation unit by replacing witness() calls. */
public void transform(JCTree.JCCompilationUnit compilationUnit) {
this.currentCompilationUnit = compilationUnit;
// First scan to validate witness calls
new ValidationScanner().scan(compilationUnit, null);
// Then apply transformations
compilationUnit.defs = translate(compilationUnit.defs);
}

/** Scanner to validate witness calls before transformation. */
private class ValidationScanner extends com.sun.source.util.TreePathScanner<Void, Void> {
@Override
public Void visitMethodInvocation(MethodInvocationTree node, Void arg) {
Parser.<MethodInvocationTree>identity()
.guard(
Parser.<MethodInvocationTree>currentElement()
.flatMap(Parser.methodMatches(WITNESS_METHOD)))
.flatMap(Parser.unaryCallArgument())
.flatMap(Parser.newAnonymousClassBody())
.flatMap(Parser.singleImplementsClause())
.flatMap(Parser.treeTypeMirror())
.flatMap(Parser.rawTypeMatches(Ty.class))
.flatMap(Parser.unaryTypeArgument())
.parse(trees, getCurrentPath(), node)
.fold(
Unit::unit,
witnessType ->
WitnessResolution.resolve(system, system.parse(witnessType))
.fold(
error -> {
trees.printMessage(
Diagnostic.Kind.ERROR,
"Failed to resolve witness for type: "
+ witnessType
+ "\nReason: "
+ error.format(),
getCurrentPath().getLeaf(),
getCurrentPath().getCompilationUnit());
return unit();
},
plan -> unit()));

return super.visitMethodInvocation(node, arg);
}
}

@Override
public Void visitMethodInvocation(MethodInvocationTree node, Void arg) {
Parser.<MethodInvocationTree>identity()
.guard(
Parser.<MethodInvocationTree>currentElement()
.flatMap(Parser.methodMatches(WITNESS_METHOD)))
.flatMap(Parser.unaryCallArgument())
.flatMap(Parser.newAnonymousClassBody())
.flatMap(Parser.singleImplementsClause())
.flatMap(Parser.treeTypeMirror())
.flatMap(Parser.rawTypeMatches(Ty.class))
.flatMap(Parser.unaryTypeArgument())
.parse(trees, getCurrentPath(), node)
.fold(
Unit::unit,
witnessType ->
WitnessResolution.resolve(system, system.parse(witnessType))
.fold(
error -> {
this.trees.printMessage(
Diagnostic.Kind.ERROR,
"Failed to resolve witness for type: "
+ witnessType
+ "\nReason: "
+ error.format(),
getCurrentPath().getLeaf(),
getCurrentPath().getCompilationUnit());
return unit();
},
plan -> unit()));

return super.visitMethodInvocation(node, arg);
public void visitApply(JCTree.JCMethodInvocation tree) {
// Try to transform witness() calls with proper attribution
TreePath path = trees.getPath(currentCompilationUnit, tree);
if (path != null) {
Maybe<TypeMirror> witnessType =
Parser.<MethodInvocationTree>identity()
.guard(
Parser.<MethodInvocationTree>currentElement()
.flatMap(Parser.methodMatches(WITNESS_METHOD)))
.flatMap(Parser.unaryCallArgument())
.flatMap(Parser.newAnonymousClassBody())
.flatMap(Parser.singleImplementsClause())
.flatMap(Parser.treeTypeMirror())
.flatMap(Parser.rawTypeMatches(Ty.class))
.flatMap(Parser.unaryTypeArgument())
.parse(trees, path, tree);

witnessType.fold(
Unit::unit,
wt -> {
WitnessResolution.resolve(system, system.parse(wt))
.fold(
error -> unit(), // Error already reported in validation
plan -> {
try {
// Build the replacement tree
JCTree.JCExpression replacement = buildInstantiationTree(plan);
if (replacement != null) {
replacement.pos = tree.pos;

// Try to resolve and attribute the replacement
// This is the key: we need to attribute the new tree
if (tree.type != null) {
// Attempt to set the type on the replacement
replacement.type = tree.type;
}

result = replacement;
}
} catch (Exception e) {
// If transformation fails, keep original
result = tree;
}
return unit();
});
return unit();
});
}

// Only call super if we didn't transform
if (result == null || result == tree) {
super.visitApply(tree);
}
}

/** Recursively builds a JCTree from an InstantiationPlan with proper type attribution. */
private JCTree.JCExpression buildInstantiationTree(WitnessResolution.InstantiationPlan plan) {
return switch (plan) {
case WitnessResolution.InstantiationPlan.PlanStep(var constructor, var dependencies) -> {
// Get the ExecutableElement for the witness constructor
ExecutableElement method = constructor.method();

// Build the method reference
JCTree.JCExpression methodSelect = buildMethodReference(method);

// Build the arguments by recursively processing dependencies
com.sun.tools.javac.util.List<JCTree.JCExpression> args =
com.sun.tools.javac.util.List.from(
dependencies.stream().map(this::buildInstantiationTree).toList());

// Create the method invocation
JCTree.JCMethodInvocation methodInvocation =
treeMaker.Apply(com.sun.tools.javac.util.List.nil(), methodSelect, args);

// Set the type on the method invocation
// The return type of the witness constructor method
if (method instanceof Symbol.MethodSymbol methodSymbol) {
methodInvocation.type = methodSymbol.getReturnType();
} else if (method.getReturnType() instanceof javax.lang.model.type.DeclaredType dt) {
// Try to get the type from the ExecutableElement
TypeMirror returnType = method.getReturnType();
if (returnType instanceof Type jcType) {
methodInvocation.type = jcType;
}
}

yield methodInvocation;
}
};
}

/**
* Builds a JCTree expression that references the given method with proper symbol information.
*/
private JCTree.JCExpression buildMethodReference(ExecutableElement method) {
// Get the enclosing class
Element enclosingElement = method.getEnclosingElement();

if (enclosingElement instanceof TypeElement typeElement) {
// Build the class name expression
JCTree.JCExpression classExpr = buildQualifiedName(typeElement);

// Build the method name
com.sun.tools.javac.util.Name methodName =
names.fromString(method.getSimpleName().toString());

// Create the field access (ClassName.methodName)
JCTree.JCFieldAccess fieldAccess = treeMaker.Select(classExpr, methodName);

// Set the symbol if method is a Symbol.MethodSymbol
if (method instanceof Symbol.MethodSymbol methodSymbol) {
fieldAccess.sym = methodSymbol;
fieldAccess.type = methodSymbol.type;
}

return fieldAccess;
} else {
throw new IllegalArgumentException(
"Method does not have a TypeElement as enclosing element: " + method);
}
}

/** Builds a qualified name expression with proper symbol and type information. */
private JCTree.JCExpression buildQualifiedName(TypeElement typeElement) {
String qualifiedName = typeElement.getQualifiedName().toString();
String[] parts = qualifiedName.split("\\.");

JCTree.JCExpression expr = treeMaker.Ident(names.fromString(parts[0]));

// Set symbol and type if typeElement is a Symbol.ClassSymbol
if (typeElement instanceof Symbol.ClassSymbol classSymbol) {
// For the final expression, set the class symbol
Symbol currentSym = classSymbol;

// Navigate through package symbols to build the path
for (int i = parts.length - 2; i >= 0; i--) {
if (currentSym.owner != null) {
currentSym = currentSym.owner;
}
}

// Build the expression with proper types
if (parts.length == 1) {
if (expr instanceof JCTree.JCIdent ident) {
ident.sym = classSymbol;
ident.type = classSymbol.type;
}
} else {
Symbol pkgSym = currentSym;
if (expr instanceof JCTree.JCIdent ident && pkgSym instanceof Symbol.PackageSymbol) {
ident.sym = pkgSym;
ident.type = pkgSym.type;
}

for (int i = 1; i < parts.length; i++) {
JCTree.JCFieldAccess select = treeMaker.Select(expr, names.fromString(parts[i]));
if (i == parts.length - 1) {
select.sym = classSymbol;
select.type = classSymbol.type;
}
expr = select;
}
}
} else {
// Fallback: just build the structure without symbols
for (int i = 1; i < parts.length; i++) {
expr = treeMaker.Select(expr, names.fromString(parts[i]));
}
}

return expr;
}
}
}
Expand Down
Loading