diff --git a/src/main/java/org/springframework/retry/annotation/RecoverAnnotationRecoveryHandler.java b/src/main/java/org/springframework/retry/annotation/RecoverAnnotationRecoveryHandler.java index d798a863..d3a21dfd 100644 --- a/src/main/java/org/springframework/retry/annotation/RecoverAnnotationRecoveryHandler.java +++ b/src/main/java/org/springframework/retry/annotation/RecoverAnnotationRecoveryHandler.java @@ -19,8 +19,13 @@ import java.lang.reflect.Method; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; -import java.util.HashMap; +import java.util.Arrays; +import java.util.List; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.LinkedHashMap; import java.util.Map; +import java.util.stream.Collectors; import org.springframework.classify.SubclassClassifier; import org.springframework.core.annotation.AnnotatedElementUtils; @@ -55,12 +60,13 @@ * @author Gianluca Medici * @author Lijinliang * @author Yanming Zhou + * @author Chih-Yu Huang */ public class RecoverAnnotationRecoveryHandler implements MethodInvocationRecoverer { private final SubclassClassifier classifier = new SubclassClassifier<>(); - private final Map methods = new HashMap<>(); + private final Map methods = new LinkedHashMap<>(); private final Object target; @@ -73,7 +79,8 @@ public RecoverAnnotationRecoveryHandler(Object target, Method method) { @Override public T recover(Object[] args, Throwable cause) { - Method method = findClosestMatch(args, cause.getClass()); + Class causeType = (cause == null) ? null : cause.getClass(); + Method method = findClosestMatch(args, causeType); if (method == null) { throw new ExhaustedRetryException("Cannot locate recovery method", cause); } @@ -112,50 +119,91 @@ private Method findMethodOnProxy(Method method, Object proxy) { } private Method findClosestMatch(Object[] args, Class cause) { - Method result = null; + if (StringUtils.hasText(this.recoverMethodName)) { + return findMethodByName(args, cause); + } - if (!StringUtils.hasText(this.recoverMethodName)) { - int min = Integer.MAX_VALUE; - for (Map.Entry entry : this.methods.entrySet()) { - Method method = entry.getKey(); - SimpleMetadata meta = entry.getValue(); - Class type = meta.getType(); - if (type == null) { - type = Throwable.class; + List withThrowable = new ArrayList<>(); + List withoutThrowable = new ArrayList<>(); + for (Method method : this.methods.keySet()) { + SimpleMetadata meta = this.methods.get(method); + if (meta.getType() != null) { + withThrowable.add(method); + } + else { + withoutThrowable.add(method); + } + } + + Method result = findMethodWithThrowable(args, cause, withThrowable); + if (result == null) { + result = findMethodWithNoThrowable(args, withoutThrowable); + } + return result; + } + + private static Method findMethodWithNoThrowable(Object[] args, List methods) { + Method result = null; + for (Method method : methods) { + if (compareParameters(args, method.getParameterTypes(), false)) { + if (result == null || result.getParameterCount() < method.getParameterCount()) { + result = method; } - if (type.isAssignableFrom(cause)) { - int distance = calculateDistance(cause, type); - if (distance < min) { - min = distance; - result = method; + } + } + return result; + } + + private Method findMethodWithThrowable(Object[] args, Class cause, List methods) { + Method result = null; + int minDistance = Integer.MAX_VALUE; + List candidates = new ArrayList<>(); + + if (cause != null) { + for (Method method : methods) { + SimpleMetadata meta = this.methods.get(method); + Class exceptionType = meta.getType(); + if (exceptionType.isAssignableFrom(cause)) { + int distance = calculateDistance(cause, exceptionType); + if (distance < minDistance) { + minDistance = distance; + candidates.clear(); + candidates.add(method); } - else if (distance == min) { - boolean parametersMatch = compareParameters(args, meta.getArgCount(), - method.getParameterTypes(), false); - if (parametersMatch) { - result = method; - } + else if (distance == minDistance) { + candidates.add(method); } } } } - else { - for (Map.Entry entry : this.methods.entrySet()) { - Method method = entry.getKey(); - if (method.getName().equals(this.recoverMethodName)) { - SimpleMetadata meta = entry.getValue(); - if ((meta.type == null || meta.type.isAssignableFrom(cause)) - && compareParameters(args, meta.getArgCount(), method.getParameterTypes(), true)) { - result = method; - break; - } + + for (Method method : candidates) { + if (compareParameters(args, method.getParameterTypes(), true)) { + if (result == null || result.getParameterCount() < method.getParameterCount()) { + result = method; } } } return result; } - private int calculateDistance(Class cause, Class type) { + private Method findMethodByName(Object[] args, Class cause) { + for (Map.Entry entry : this.methods.entrySet()) { + Method method = entry.getKey(); + if (method.getName().equals(this.recoverMethodName)) { + SimpleMetadata meta = entry.getValue(); + Class exceptionType = meta.getType(); + if (exceptionType == null || (cause != null && exceptionType.isAssignableFrom(cause))) { + if (compareParameters(args, method.getParameterTypes(), exceptionType != null)) { + return method; + } + } + } + } + return null; + } + + private static int calculateDistance(Class cause, Class type) { int result = 0; Class current = cause; while (current != type && current != Throwable.class) { @@ -165,53 +213,59 @@ private int calculateDistance(Class cause, Class[] parameterTypes, - boolean withRecoverMethodName) { - if ((withRecoverMethodName && argCount == args.length) || argCount == (args.length + 1)) { - int startingIndex = 0; - if (parameterTypes.length > 0 && Throwable.class.isAssignableFrom(parameterTypes[0])) { - startingIndex = 1; + private static boolean compareParameters(Object[] args, Class[] parameterTypes, boolean hasThrowable) { + int argCount = args.length; + int paramCount = parameterTypes.length; + int argIndex = 0; + int paramIndex = hasThrowable ? 1 : 0; + + while (paramIndex < paramCount) { + Class parameterType = parameterTypes[paramIndex]; + Object argument = (argIndex < argCount) ? args[argIndex] : null; + + if (argument == null && parameterType.isPrimitive()) { + return false; } - for (int i = startingIndex; i < parameterTypes.length; i++) { - final Object argument = i - startingIndex < args.length ? args[i - startingIndex] : null; - if (argument == null) { - continue; - } - Class parameterType = parameterTypes[i]; - parameterType = ClassUtils.resolvePrimitiveIfNecessary(parameterType); - if (!parameterType.isAssignableFrom(argument.getClass())) { - return false; - } + if (argument != null && !ClassUtils.isAssignable(parameterType, argument.getClass())) { + return false; } - return true; + paramIndex++; + argIndex++; } - return false; + return true; } private void init(final Object target, Method method) { - final Map, Method> types = new HashMap<>(); + final Map, Method> types = new LinkedHashMap<>(); final Method failingMethod = method; Retryable retryable = AnnotatedElementUtils.findMergedAnnotation(method, Retryable.class); if (retryable != null) { this.recoverMethodName = retryable.recover(); } - ReflectionUtils.doWithMethods(target.getClass(), candidate -> { + Method[] declared = target.getClass().getDeclaredMethods(); + Arrays.sort(declared, Comparator.comparing(Method::getName) + .thenComparingInt(Method::getParameterCount) + .thenComparing( + m -> Arrays.stream(m.getParameterTypes()).map(Class::getName).collect(Collectors.joining(",")))); + + for (Method candidate : declared) { Recover recover = AnnotatedElementUtils.findMergedAnnotation(candidate, Recover.class); if (recover == null) { recover = findAnnotationOnTarget(target, candidate); } - if (recover != null && failingMethod.getGenericReturnType() instanceof ParameterizedType - && candidate.getGenericReturnType() instanceof ParameterizedType) { - if (isParameterizedTypeAssignable((ParameterizedType) candidate.getGenericReturnType(), - (ParameterizedType) failingMethod.getGenericReturnType())) { + if (recover != null) { + if (failingMethod.getGenericReturnType() instanceof ParameterizedType + && candidate.getGenericReturnType() instanceof ParameterizedType) { + if (isParameterizedTypeAssignable((ParameterizedType) candidate.getGenericReturnType(), + (ParameterizedType) failingMethod.getGenericReturnType())) { + putToMethodsMap(candidate, types); + } + } + else if (candidate.getReturnType().isAssignableFrom(failingMethod.getReturnType())) { putToMethodsMap(candidate, types); } } - else if (recover != null && candidate.getReturnType().isAssignableFrom(failingMethod.getReturnType())) { - putToMethodsMap(candidate, types); - } - }); - this.classifier.setTypeMap(types); + } optionallyFilterMethodsBy(failingMethod.getReturnType()); } @@ -261,11 +315,10 @@ private void putToMethodsMap(Method method, Map, Meth @SuppressWarnings("unchecked") Class type = (Class) parameterTypes[0]; types.put(type, method); - RecoverAnnotationRecoveryHandler.this.methods.put(method, new SimpleMetadata(parameterTypes.length, type)); + this.methods.put(method, new SimpleMetadata(parameterTypes.length, type)); } else { - RecoverAnnotationRecoveryHandler.this.classifier.setDefaultValue(method); - RecoverAnnotationRecoveryHandler.this.methods.put(method, new SimpleMetadata(parameterTypes.length, null)); + this.methods.put(method, new SimpleMetadata(parameterTypes.length, null)); } } @@ -280,7 +333,7 @@ private Recover findAnnotationOnTarget(Object target, Method method) { } private void optionallyFilterMethodsBy(Class returnClass) { - Map filteredMethods = new HashMap<>(); + Map filteredMethods = new LinkedHashMap<>(); for (Method method : this.methods.keySet()) { if (method.getReturnType() == returnClass) { filteredMethods.put(method, this.methods.get(method));