// Java tutorial
/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator.aggregation; import com.facebook.presto.operator.ParametricImplementationsGroup; import com.facebook.presto.operator.annotations.FunctionsParserHelper; import com.facebook.presto.spi.function.AccumulatorState; import com.facebook.presto.spi.function.AggregationFunction; import com.facebook.presto.spi.function.AggregationStateSerializerFactory; import com.facebook.presto.spi.function.CombineFunction; import com.facebook.presto.spi.function.InputFunction; import com.facebook.presto.spi.function.OutputFunction; import com.facebook.presto.spi.type.TypeSignature; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import javax.annotation.Nullable; import java.lang.reflect.AnnotatedElement; import java.lang.reflect.Method; import java.util.Arrays; import java.util.List; import java.util.Optional; import java.util.Set; import static com.facebook.presto.operator.aggregation.AggregationImplementation.Parser.parseImplementation; import static com.facebook.presto.operator.annotations.FunctionsParserHelper.parseDescription; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static java.util.Objects.requireNonNull; public class 
AggregationFromAnnotationsParser { private AggregationFromAnnotationsParser() { } // This function should only be used for function matching for testing purposes. // General purpose function matching is done through FunctionRegistry. @VisibleForTesting public static ParametricAggregation parseFunctionDefinitionWithTypesConstraint(Class<?> clazz, TypeSignature returnType, List<TypeSignature> argumentTypes) { requireNonNull(returnType, "returnType is null"); requireNonNull(argumentTypes, "argumentTypes is null"); for (ParametricAggregation aggregation : parseFunctionDefinitions(clazz)) { if (aggregation.getSignature().getReturnType().equals(returnType) && aggregation.getSignature().getArgumentTypes().equals(argumentTypes)) { return aggregation; } } throw new IllegalArgumentException( String.format("No method with return type %s and arguments %s", returnType, argumentTypes)); } public static List<ParametricAggregation> parseFunctionDefinitions(Class<?> aggregationDefinition) { AggregationFunction aggregationAnnotation = aggregationDefinition.getAnnotation(AggregationFunction.class); requireNonNull(aggregationAnnotation, "aggregationAnnotation is null"); ImmutableList.Builder<ParametricAggregation> builder = ImmutableList.builder(); for (Class<?> stateClass : getStateClasses(aggregationDefinition)) { Method combineFunction = getCombineFunction(aggregationDefinition, stateClass); Optional<Method> aggregationStateSerializerFactory = getAggregationStateSerializerFactory( aggregationDefinition, stateClass); for (Method outputFunction : getOutputFunctions(aggregationDefinition, stateClass)) { for (Method inputFunction : getInputFunctions(aggregationDefinition, stateClass)) { for (AggregationHeader header : parseHeaders(aggregationDefinition, outputFunction)) { AggregationImplementation onlyImplementation = parseImplementation(aggregationDefinition, header, stateClass, inputFunction, outputFunction, combineFunction, aggregationStateSerializerFactory); 
ParametricImplementationsGroup<AggregationImplementation> implementations = ParametricImplementationsGroup .of(onlyImplementation); builder.add( new ParametricAggregation(implementations.getSignature(), header, implementations)); } } } } return builder.build(); } public static ParametricAggregation parseFunctionDefinition(Class<?> aggregationDefinition) { ParametricImplementationsGroup.Builder<AggregationImplementation> implementationsBuilder = ParametricImplementationsGroup .builder(); AggregationHeader header = parseHeader(aggregationDefinition); for (Class<?> stateClass : getStateClasses(aggregationDefinition)) { Method combineFunction = getCombineFunction(aggregationDefinition, stateClass); Optional<Method> aggregationStateSerializerFactory = getAggregationStateSerializerFactory( aggregationDefinition, stateClass); Method outputFunction = getOnlyElement(getOutputFunctions(aggregationDefinition, stateClass)); Method inputFunction = getOnlyElement(getInputFunctions(aggregationDefinition, stateClass)); AggregationImplementation implementation = parseImplementation(aggregationDefinition, header, stateClass, inputFunction, outputFunction, combineFunction, aggregationStateSerializerFactory); implementationsBuilder.addImplementation(implementation); } ParametricImplementationsGroup<AggregationImplementation> implementations = implementationsBuilder.build(); return new ParametricAggregation(implementations.getSignature(), header, implementations); } private static Optional<Method> getAggregationStateSerializerFactory(Class<?> aggregationDefinition, Class<?> stateClass) { // Only include methods that match this state class List<Method> stateSerializerFactories = FunctionsParserHelper .findPublicStaticMethodsWithAnnotation(aggregationDefinition, AggregationStateSerializerFactory.class) .stream() .filter(method -> ((AggregationStateSerializerFactory) method .getAnnotation(AggregationStateSerializerFactory.class)).value().equals(stateClass)) .collect(toImmutableList()); if 
(stateSerializerFactories.isEmpty()) { return Optional.empty(); } checkArgument(stateSerializerFactories.size() == 1, String.format( "Expect at most 1 @AggregationStateSerializerFactory(%s.class) annotation, found %s in %s", stateClass.toGenericString(), stateSerializerFactories.size(), aggregationDefinition.toGenericString())); return Optional.of(getOnlyElement(stateSerializerFactories)); } private static AggregationHeader parseHeader(AnnotatedElement aggregationDefinition) { AggregationFunction aggregationAnnotation = aggregationDefinition.getAnnotation(AggregationFunction.class); requireNonNull(aggregationAnnotation, "aggregationAnnotation is null"); return new AggregationHeader(aggregationAnnotation.value(), parseDescription(aggregationDefinition), aggregationAnnotation.decomposable()); } private static List<AggregationHeader> parseHeaders(AnnotatedElement aggregationDefinition, AnnotatedElement toParse) { AggregationFunction aggregationAnnotation = aggregationDefinition.getAnnotation(AggregationFunction.class); return getNames(toParse, aggregationAnnotation).stream().map(name -> new AggregationHeader(name, parseDescription(aggregationDefinition, toParse), aggregationAnnotation.decomposable())) .collect(toImmutableList()); } private static List<String> getNames(@Nullable AnnotatedElement outputFunction, AggregationFunction aggregationAnnotation) { List<String> defaultNames = ImmutableList.<String>builder().add(aggregationAnnotation.value()) .addAll(Arrays.asList(aggregationAnnotation.alias())).build(); if (outputFunction == null) { return defaultNames; } AggregationFunction annotation = outputFunction.getAnnotation(AggregationFunction.class); if (annotation == null) { return defaultNames; } else { return ImmutableList.<String>builder().add(annotation.value()).addAll(Arrays.asList(annotation.alias())) .build(); } } public static Method getCombineFunction(Class<?> clazz, Class<?> stateClass) { // Only include methods that match this state class List<Method> 
combineFunctions = FunctionsParserHelper .findPublicStaticMethodsWithAnnotation(clazz, CombineFunction.class).stream() .filter(method -> method.getParameterTypes()[AggregationImplementation.Parser .findAggregationStateParamId(method, 0)] == stateClass) .filter(method -> method.getParameterTypes()[AggregationImplementation.Parser .findAggregationStateParamId(method, 1)] == stateClass) .collect(toImmutableList()); checkArgument(combineFunctions.size() == 1, String.format( "There must be exactly one @CombineFunction in class %s for the @AggregationState %s ", clazz.toGenericString(), stateClass.toGenericString())); return getOnlyElement(combineFunctions); } private static List<Method> getOutputFunctions(Class<?> clazz, Class<?> stateClass) { // Only include methods that match this state class List<Method> outputFunctions = FunctionsParserHelper .findPublicStaticMethodsWithAnnotation(clazz, OutputFunction.class).stream() .filter(method -> method.getParameterTypes()[AggregationImplementation.Parser .findAggregationStateParamId(method)] == stateClass) .collect(toImmutableList()); checkArgument(!outputFunctions.isEmpty(), "Aggregation has no output functions"); return outputFunctions; } private static List<Method> getInputFunctions(Class<?> clazz, Class<?> stateClass) { // Only include methods that match this state class List<Method> inputFunctions = FunctionsParserHelper .findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class).stream() .filter(method -> (method.getParameterTypes()[AggregationImplementation.Parser .findAggregationStateParamId(method)] == stateClass)) .collect(toImmutableList()); checkArgument(!inputFunctions.isEmpty(), "Aggregation has no input functions"); return inputFunctions; } private static Set<Class<?>> getStateClasses(Class<?> clazz) { ImmutableSet.Builder<Class<?>> builder = ImmutableSet.builder(); for (Method inputFunction : FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class)) { 
checkArgument(inputFunction.getParameterTypes().length > 0, "Input function has no parameters"); Class<?> stateClass = AggregationImplementation.Parser.findAggregationStateParamType(inputFunction); checkArgument(AccumulatorState.class.isAssignableFrom(stateClass), "stateClass is not a subclass of AccumulatorState"); builder.add(stateClass); } ImmutableSet<Class<?>> stateClasses = builder.build(); checkArgument(!stateClasses.isEmpty(), "No input functions found"); return stateClasses; } }