com.linecorp.armeria.server.thrift.ThriftDocServicePlugin.java Source code

Java tutorial

Introduction

Here is the source code for com.linecorp.armeria.server.thrift.ThriftDocServicePlugin.java

Source

/*
 *  Copyright 2017 LINE Corporation
 *
 *  LINE Corporation licenses this file to you under the Apache License,
 *  version 2.0 (the "License"); you may not use this file except in compliance
 *  with the License. You may obtain a copy of the License at:
 *
 *    https://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  License for the specific language governing permissions and limitations
 *  under the License.
 */

package com.linecorp.armeria.server.thrift;

import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Objects.requireNonNull;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import javax.annotation.Nullable;

import org.apache.thrift.TBase;
import org.apache.thrift.TException;
import org.apache.thrift.TFieldIdEnum;
import org.apache.thrift.TFieldRequirementType;
import org.apache.thrift.TSerializer;
import org.apache.thrift.meta_data.EnumMetaData;
import org.apache.thrift.meta_data.FieldMetaData;
import org.apache.thrift.meta_data.FieldValueMetaData;
import org.apache.thrift.meta_data.ListMetaData;
import org.apache.thrift.meta_data.MapMetaData;
import org.apache.thrift.meta_data.SetMetaData;
import org.apache.thrift.meta_data.StructMetaData;
import org.apache.thrift.protocol.TType;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import com.linecorp.armeria.common.thrift.ThriftProtocolFactories;
import com.linecorp.armeria.server.PathMapping;
import com.linecorp.armeria.server.Service;
import com.linecorp.armeria.server.ServiceConfig;
import com.linecorp.armeria.server.docs.DocServicePlugin;
import com.linecorp.armeria.server.docs.EndpointInfo;
import com.linecorp.armeria.server.docs.EnumInfo;
import com.linecorp.armeria.server.docs.EnumValueInfo;
import com.linecorp.armeria.server.docs.ExceptionInfo;
import com.linecorp.armeria.server.docs.FieldInfo;
import com.linecorp.armeria.server.docs.FieldRequirement;
import com.linecorp.armeria.server.docs.MethodInfo;
import com.linecorp.armeria.server.docs.NamedTypeInfo;
import com.linecorp.armeria.server.docs.ServiceInfo;
import com.linecorp.armeria.server.docs.ServiceSpecification;
import com.linecorp.armeria.server.docs.StructInfo;
import com.linecorp.armeria.server.docs.TypeSignature;

/**
 * {@link DocServicePlugin} implementation that supports {@link THttpService}s.
 */
public class ThriftDocServicePlugin implements DocServicePlugin {

    /**
     * Suffix of the class name of a generated request struct ({@code <method>_args}). It is used to
     * recognize example requests and to cut the method name out of the struct class name.
     */
    private static final String REQUEST_STRUCT_SUFFIX = "_args";

    // Shared signatures of the base types. VOID doubles as the return type of a method whose result
    // struct has no 'success' field; the others are returned by toTypeSignature().
    private static final TypeSignature VOID = TypeSignature.ofBase("void");
    private static final TypeSignature BOOL = TypeSignature.ofBase("bool");
    private static final TypeSignature I8 = TypeSignature.ofBase("i8");
    private static final TypeSignature I16 = TypeSignature.ofBase("i16");
    private static final TypeSignature I32 = TypeSignature.ofBase("i32");
    private static final TypeSignature I64 = TypeSignature.ofBase("i64");
    private static final TypeSignature DOUBLE = TypeSignature.ofBase("double");
    private static final TypeSignature STRING = TypeSignature.ofBase("string");
    private static final TypeSignature BINARY = TypeSignature.ofBase("binary");

    // Source of the documentation strings returned by loadDocStrings().
    private final ThriftDocStringExtractor docstringExtractor = new ThriftDocStringExtractor();

    // Methods related with generating a service specification.

    /**
     * Returns the service types handled by this plugin, which is {@link THttpService} only.
     */
    @Override
    public Set<Class<? extends Service<?, ?>>> supportedServiceTypes() {
        return ImmutableSet.of(THttpService.class);
    }

    /**
     * Generates the {@link ServiceSpecification} of the given {@link THttpService} configurations.
     *
     * <p>Every interface of every service entry is grouped by its enclosing class, and an
     * {@link EndpointInfo} is recorded for each configuration whose path mapping exposes an exact path
     * or a prefix. The collected entries are then handed to {@link #generate(List)}.
     *
     * @param serviceConfigs the configurations of the services to document
     * @return the specification built from the collected entries
     */
    @Override
    public ServiceSpecification generateSpecification(Set<ServiceConfig> serviceConfigs) {

        // Insertion-ordered so that the entries come out in the order the services were first seen.
        final Map<Class<?>, EntryBuilder> buildersByServiceClass = new LinkedHashMap<>();

        for (ServiceConfig config : serviceConfigs) {
            final THttpService thriftService = config.service().as(THttpService.class).get();
            thriftService.entries().forEach((serviceName, entry) -> {
                for (Class<?> iface : entry.interfaces()) {
                    final EntryBuilder builder = buildersByServiceClass.computeIfAbsent(
                            iface.getEnclosingClass(), EntryBuilder::new);

                    // Only exact and prefix path mappings have an endpoint path that can be determined;
                    // an exact path wins over a prefix.
                    final PathMapping mapping = config.pathMapping();
                    final String prefixOrNull = mapping.prefix().orElse(null);
                    final String path = mapping.exactPath().orElse(prefixOrNull);
                    if (path == null) {
                        continue;
                    }

                    builder.endpoint(new EndpointInfo(config.virtualHost().hostnamePattern(), path,
                            serviceName, thriftService.defaultSerializationFormat(),
                            thriftService.allowedSerializationFormats()));
                }
            });
        }

        final List<Entry> entries = new ArrayList<>();
        for (EntryBuilder builder : buildersByServiceClass.values()) {
            entries.add(builder.build());
        }
        return generate(entries);
    }

    /**
     * Converts the given entries into {@link ServiceInfo}s, in order, and generates a
     * {@link ServiceSpecification} from them, resolving named types via
     * {@code newNamedTypeInfo(TypeSignature)}.
     *
     * @param entries the service types and their endpoints
     */
    @VisibleForTesting
    static ServiceSpecification generate(List<Entry> entries) {
        final ImmutableList.Builder<ServiceInfo> collected = ImmutableList.builder();
        for (Entry entry : entries) {
            collected.add(newServiceInfo(entry.serviceType, entry.endpointInfos));
        }
        final List<ServiceInfo> services = collected.build();

        return ServiceSpecification.generate(services, ThriftDocServicePlugin::newNamedTypeInfo);
    }

    /**
     * Creates the {@link ServiceInfo} of a generated Thrift service class.
     *
     * <p>The methods are taken from the nested {@code Iface} interface of {@code serviceClass}, which is
     * looked up by name ({@code <serviceClass>$Iface}) through the class loader of {@code serviceClass}
     * without initializing it.
     *
     * @param serviceClass the generated service class that encloses {@code Iface}
     * @param endpoints the endpoints attached to every method of the service
     * @return the service information named after {@code serviceClass}
     * @throws NullPointerException if {@code serviceClass} is {@code null}
     * @throws IllegalStateException if the {@code Iface} class cannot be found
     */
    @VisibleForTesting
    static ServiceInfo newServiceInfo(Class<?> serviceClass, Iterable<EndpointInfo> endpoints) {
        requireNonNull(serviceClass, "serviceClass");

        final String name = serviceClass.getName();
        final ClassLoader serviceClassLoader = serviceClass.getClassLoader();
        final String interfaceClassName = name + "$Iface";
        final Class<?> interfaceClass;
        try {
            interfaceClass = Class.forName(interfaceClassName, false, serviceClassLoader);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException("failed to find a class: " + interfaceClassName, e);
        }

        // Materialize the methods up front. An Iterable backed by 'stream::iterator' can be traversed
        // only once and would fail with an IllegalStateException on a second traversal.
        final List<MethodInfo> methodInfos = Arrays.stream(interfaceClass.getDeclaredMethods())
                .map(m -> newMethodInfo(m, endpoints))
                .collect(toImmutableList());

        return new ServiceInfo(name, methodInfos);
    }

    /**
     * Creates the {@link MethodInfo} of a method declared by a generated {@code Iface} interface.
     *
     * <p>The request and result structs are looked up by name ({@code <service>$<method>_args} and
     * {@code <service>$<method>_result}) through the class loader of the service class, without
     * initializing them.
     *
     * @param method a method of a generated {@code Iface} interface
     * @param endpoints the endpoints attached to the method
     * @throws NullPointerException if {@code method} is {@code null}
     * @throws IllegalStateException if the {@code _args} struct cannot be found
     */
    private static MethodInfo newMethodInfo(Method method, Iterable<EndpointInfo> endpoints) {
        requireNonNull(method, "method");

        final String methodName = method.getName();

        // The declaring class of the method is 'Iface'; its declaring class is the service class.
        final Class<?> serviceClass = method.getDeclaringClass().getDeclaringClass();
        final ClassLoader loader = serviceClass.getClassLoader();
        final String structNamePrefix = serviceClass.getName() + '$' + methodName;

        final String argsClassName = structNamePrefix + "_args";
        final Class<?> rawArgsClass;
        try {
            rawArgsClass = Class.forName(argsClassName, false, loader);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException("failed to find a class: " + argsClassName, e);
        }

        Class<?> rawResultClass = null;
        try {
            rawResultClass = Class.forName(structNamePrefix + "_result", false, loader);
        } catch (ClassNotFoundException ignored) {
            // Oneway function does not have a result type; leave it null.
        }

        @SuppressWarnings("unchecked")
        final Class<? extends TBase<?, ?>> argsClass = (Class<? extends TBase<?, ?>>) rawArgsClass;
        @SuppressWarnings("unchecked")
        final Class<? extends TBase<?, ?>> resultClass = (Class<? extends TBase<?, ?>>) rawResultClass;
        @SuppressWarnings("unchecked")
        final Class<? extends TException>[] exceptionClasses =
                (Class<? extends TException>[]) method.getExceptionTypes();

        return newMethodInfo(methodName, argsClass, resultClass, exceptionClasses, endpoints);
    }

    /**
     * Creates a {@link MethodInfo} from the generated request and result structs of a method.
     *
     * @param name the method name
     * @param argsClass the request struct; each of its fields becomes a parameter
     * @param resultClass the result struct, or {@code null} for a oneway function. The return type is
     *                    the type of its {@code success} field, or {@code void} if there is none.
     * @param exceptionClasses the declared exceptions; {@link TException} itself is left out
     * @param endpoints the endpoints attached to the method
     * @throws NullPointerException if any argument other than {@code resultClass} is {@code null}
     */
    private static MethodInfo newMethodInfo(String name, Class<? extends TBase<?, ?>> argsClass,
            @Nullable Class<? extends TBase<?, ?>> resultClass, Class<? extends TException>[] exceptionClasses,
            Iterable<EndpointInfo> endpoints) {
        requireNonNull(name, "name");
        requireNonNull(argsClass, "argsClass");
        requireNonNull(exceptionClasses, "exceptionClasses");
        requireNonNull(endpoints, "endpoints");

        final ImmutableList.Builder<FieldInfo> parameterBuilder = ImmutableList.builder();
        for (FieldMetaData argMetaData : FieldMetaData.getStructMetaDataMap(argsClass).values()) {
            parameterBuilder.add(newFieldInfo(argsClass, argMetaData));
        }
        final List<FieldInfo> parameters = parameterBuilder.build();

        // The return type is 'void' unless the result struct carries a 'success' field.
        TypeSignature returnTypeSignature = VOID;
        if (resultClass != null) { // Function isn't "oneway" function
            final Map<? extends TFieldIdEnum, FieldMetaData> resultMetaData = FieldMetaData
                    .getStructMetaDataMap(resultClass);

            for (FieldMetaData candidate : resultMetaData.values()) {
                if ("success".equals(candidate.fieldName)) {
                    returnTypeSignature = newFieldInfo(resultClass, candidate).typeSignature();
                    break;
                }
            }
        }

        final ImmutableList.Builder<TypeSignature> exceptionBuilder = ImmutableList.builder();
        for (Class<? extends TException> exceptionClass : exceptionClasses) {
            if (exceptionClass != TException.class) {
                exceptionBuilder.add(TypeSignature.ofNamed(exceptionClass));
            }
        }
        final List<TypeSignature> exceptionTypeSignatures = exceptionBuilder.build();

        return new MethodInfo(name, returnTypeSignature, parameters, exceptionTypeSignatures, endpoints);
    }

    /**
     * Creates the {@link NamedTypeInfo} of the class held by the given named type signature: an
     * {@link EnumInfo} for an enum, an {@link ExceptionInfo} for a {@link TException} and a
     * {@link StructInfo} for anything else, which is asserted to be a {@link TBase}.
     */
    private static NamedTypeInfo newNamedTypeInfo(TypeSignature typeSignature) {
        final Class<?> namedType = (Class<?>) typeSignature.namedTypeDescriptor().get();

        if (namedType.isEnum()) {
            return newEnumInfo(namedType);
        } else if (TException.class.isAssignableFrom(namedType)) {
            return newExceptionInfo(namedType.asSubclass(TException.class));
        } else {
            assert TBase.class.isAssignableFrom(namedType);
            @SuppressWarnings("unchecked")
            final Class<? extends TBase<?, ?>> structType = (Class<? extends TBase<?, ?>>) namedType;
            return newStructInfo(structType);
        }
    }

    /**
     * Creates the {@link EnumInfo} of the given enum class.
     *
     * <p>The values are the string forms ({@link String#valueOf(Object)}) of the enum constants, in the
     * order returned by {@link Class#getEnumConstants()}. Using {@code getEnumConstants()} rather than
     * reading the constant fields reflectively keeps the constants of an enum that is not accessible
     * from this class; a reflective field read fails for such an enum, which used to leave the value
     * list silently empty.
     *
     * @param enumClass the enum class; a class that is not an enum yields an {@link EnumInfo} without
     *                  values
     * @return the enum information named after {@code enumClass}
     * @throws NullPointerException if {@code enumClass} is {@code null}
     */
    @VisibleForTesting
    static EnumInfo newEnumInfo(Class<?> enumClass) {
        requireNonNull(enumClass, "enumClass");

        final List<EnumValueInfo> values = new ArrayList<>();
        // getEnumConstants() returns null when the class does not represent an enum type.
        final Object[] constants = enumClass.getEnumConstants();
        if (constants != null) {
            for (Object constant : constants) {
                values.add(new EnumValueInfo(String.valueOf(constant)));
            }
        }

        final String name = enumClass.getName();
        return new EnumInfo(name, values);
    }

    /**
     * Creates the {@link StructInfo} of the given generated struct class, with one {@link FieldInfo}
     * per entry of its Thrift field metadata map.
     *
     * @param structClass the generated struct class
     * @return the struct information named after {@code structClass}
     * @throws NullPointerException if {@code structClass} is {@code null}
     */
    @VisibleForTesting
    static StructInfo newStructInfo(Class<? extends TBase<?, ?>> structClass) {
        // Validate by name like the sibling factory methods do.
        requireNonNull(structClass, "structClass");
        final String name = structClass.getName();

        final Map<?, FieldMetaData> metaDataMap = FieldMetaData.getStructMetaDataMap(structClass);
        final List<FieldInfo> fields = metaDataMap.values().stream()
                .map(fieldMetaData -> newFieldInfo(structClass, fieldMetaData)).collect(toImmutableList());

        return new StructInfo(name, fields);
    }

    /**
     * Creates the {@link ExceptionInfo} of the given exception class.
     *
     * <p>The fields are read from the static {@code metaDataMap} field declared by the class. A class
     * that declares no such field yields an {@link ExceptionInfo} without fields.
     *
     * @param exceptionClass the exception class
     * @throws NullPointerException if {@code exceptionClass} is {@code null}
     */
    @VisibleForTesting
    static ExceptionInfo newExceptionInfo(Class<? extends TException> exceptionClass) {
        requireNonNull(exceptionClass, "exceptionClass");
        final String name = exceptionClass.getName();

        final Map<?, FieldMetaData> metaDataMap;
        try {
            @SuppressWarnings("unchecked")
            final Map<?, FieldMetaData> declaredMap = (Map<?, FieldMetaData>) exceptionClass
                    .getDeclaredField("metaDataMap").get(null);
            metaDataMap = declaredMap;
        } catch (IllegalAccessException e) {
            throw new AssertionError("will not happen", e);
        } catch (NoSuchFieldException ignored) {
            // No metadata map declared: report the exception without fields.
            return new ExceptionInfo(name, Collections.emptyList());
        }

        final ImmutableList.Builder<FieldInfo> fieldBuilder = ImmutableList.builder();
        for (FieldMetaData fieldMetaData : metaDataMap.values()) {
            fieldBuilder.add(newFieldInfo(exceptionClass, fieldMetaData));
        }
        final List<FieldInfo> fields = fieldBuilder.build();

        return new ExceptionInfo(name, fields);
    }

    /**
     * Creates the {@link FieldInfo} of a field of {@code parentType}.
     *
     * @param parentType the type that declares the field
     * @param fieldMetaData the Thrift metadata of the field
     * @throws NullPointerException if {@code fieldMetaData} is {@code null}
     */
    @VisibleForTesting
    static FieldInfo newFieldInfo(Class<?> parentType, FieldMetaData fieldMetaData) {
        requireNonNull(fieldMetaData, "fieldMetaData");
        final FieldValueMetaData valueMetaData = fieldMetaData.valueMetaData;

        // Special case: a struct field that refers to its own struct, which the Thrift compiler
        // handles as a typedef named after the struct.
        final boolean refersToParent = valueMetaData.isStruct() && valueMetaData.isTypedef()
                && parentType.getSimpleName().equals(valueMetaData.getTypedefName());

        final TypeSignature typeSignature = refersToParent ? TypeSignature.ofNamed(parentType)
                                                           : toTypeSignature(valueMetaData);

        final FieldRequirement requirement = convertRequirement(fieldMetaData.requirementType);
        return new FieldInfo(fieldMetaData.fieldName, requirement, typeSignature);
    }

    /**
     * Converts Thrift field value metadata into a {@link TypeSignature}.
     *
     * <p>Struct and enum metadata map to named types, and list, set and map metadata map to containers
     * whose element types are converted recursively. Binary and the remaining base types map to base
     * signatures. Anything else becomes an unresolved type named after its typedef, or
     * {@code "unknown"} if it has none.
     */
    @VisibleForTesting
    static TypeSignature toTypeSignature(FieldValueMetaData fieldValueMetaData) {
        if (fieldValueMetaData instanceof StructMetaData) {
            final StructMetaData structMetaData = (StructMetaData) fieldValueMetaData;
            return TypeSignature.ofNamed(structMetaData.structClass);
        }

        if (fieldValueMetaData instanceof EnumMetaData) {
            final EnumMetaData enumMetaData = (EnumMetaData) fieldValueMetaData;
            return TypeSignature.ofNamed(enumMetaData.enumClass);
        }

        if (fieldValueMetaData instanceof ListMetaData) {
            final ListMetaData listMetaData = (ListMetaData) fieldValueMetaData;
            return TypeSignature.ofList(toTypeSignature(listMetaData.elemMetaData));
        }

        if (fieldValueMetaData instanceof SetMetaData) {
            final SetMetaData setMetaData = (SetMetaData) fieldValueMetaData;
            return TypeSignature.ofSet(toTypeSignature(setMetaData.elemMetaData));
        }

        if (fieldValueMetaData instanceof MapMetaData) {
            final MapMetaData mapMetaData = (MapMetaData) fieldValueMetaData;
            final TypeSignature keySignature = toTypeSignature(mapMetaData.keyMetaData);
            final TypeSignature valueSignature = toTypeSignature(mapMetaData.valueMetaData);
            return TypeSignature.ofMap(keySignature, valueSignature);
        }

        // Checked before the base types so that binary is not reported as one of them.
        if (fieldValueMetaData.isBinary()) {
            return BINARY;
        }

        switch (fieldValueMetaData.type) {
        case TType.VOID:
            return VOID;
        case TType.BOOL:
            return BOOL;
        case TType.BYTE:
            return I8;
        case TType.I16:
            return I16;
        case TType.I32:
            return I32;
        case TType.I64:
            return I64;
        case TType.DOUBLE:
            return DOUBLE;
        case TType.STRING:
            return STRING;
        default:
            // Not a base type; fall through to the unresolved type below.
            break;
        }

        final String typedefName = fieldValueMetaData.isTypedef() ? fieldValueMetaData.getTypedefName()
                                                                  : null;
        return TypeSignature.ofUnresolved(firstNonNull(typedefName, "unknown"));
    }

    /**
     * Maps a {@link TFieldRequirementType} constant to the matching {@link FieldRequirement}.
     *
     * @throws IllegalArgumentException if {@code value} is none of the known requirement types
     */
    private static FieldRequirement convertRequirement(byte value) {
        if (value == TFieldRequirementType.REQUIRED) {
            return FieldRequirement.REQUIRED;
        }
        if (value == TFieldRequirementType.OPTIONAL) {
            return FieldRequirement.OPTIONAL;
        }
        if (value == TFieldRequirementType.DEFAULT) {
            return FieldRequirement.DEFAULT;
        }
        throw new IllegalArgumentException("unknown requirement type: " + value);
    }

    /**
     * A service type together with an immutable copy of the endpoints that serve it.
     */
    @VisibleForTesting
    static final class Entry {
        /** The service type this entry describes. */
        final Class<?> serviceType;
        /** The endpoints that serve {@link #serviceType}. */
        final List<EndpointInfo> endpointInfos;

        /**
         * Creates a new entry, copying {@code endpoints} so that later changes to the given list do
         * not affect this entry.
         */
        Entry(Class<?> type, List<EndpointInfo> endpoints) {
            serviceType = type;
            endpointInfos = ImmutableList.copyOf(endpoints);
        }
    }

    /**
     * Accumulates the endpoints of a service type and builds an {@link Entry} from them.
     */
    @VisibleForTesting
    static final class EntryBuilder {
        private final Class<?> serviceType;
        private final List<EndpointInfo> endpointInfos = new ArrayList<>();

        /**
         * Creates a builder for the given service type.
         *
         * @throws NullPointerException if {@code serviceType} is {@code null}
         */
        EntryBuilder(Class<?> serviceType) {
            requireNonNull(serviceType, "serviceType");
            this.serviceType = serviceType;
        }

        /**
         * Adds an endpoint and returns this builder.
         *
         * @throws NullPointerException if {@code endpointInfo} is {@code null}
         */
        EntryBuilder endpoint(EndpointInfo endpointInfo) {
            requireNonNull(endpointInfo, "endpointInfo");
            endpointInfos.add(endpointInfo);
            return this;
        }

        /** Builds an {@link Entry} from the service type and the endpoints added so far. */
        Entry build() {
            return new Entry(serviceType, endpointInfos);
        }
    }

    // Methods related with extracting documentation strings.

    /**
     * Loads the documentation strings available through the class loaders of all interfaces of the
     * given {@link THttpService}s. When the same key is produced more than once, the first value
     * wins.
     *
     * @param serviceConfigs the configurations of the services to document
     * @return an immutable map of the documentation strings
     */
    @Override
    public Map<String, String> loadDocStrings(Set<ServiceConfig> serviceConfigs) {
        return serviceConfigs.stream()
                .map(config -> config.service().as(THttpService.class).get())
                .flatMap(thriftService -> thriftService.entries().values().stream())
                .flatMap(entry -> entry.interfaces().stream().map(Class::getClassLoader))
                .map(loader -> docstringExtractor.getAllDocStrings(loader))
                .flatMap(docStrings -> docStrings.entrySet().stream())
                .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue,
                        (first, duplicate) -> first));
    }

    // Methods related with serializing example requests.

    /**
     * Returns the example request types handled by this plugin, which is {@link TBase} only.
     */
    @Override
    public Set<Class<?>> supportedExampleRequestTypes() {
        return ImmutableSet.of(TBase.class);
    }

    /**
     * Guesses the service name of an example request: the name of the class that encloses the
     * request struct.
     *
     * @return the service name, or {@link Optional#empty()} if {@code asTBase(Object)} does not accept
     *         the request as a request struct
     */
    @Override
    public Optional<String> guessServiceName(Object exampleRequest) {
        final TBase<?, ?> requestStruct = asTBase(exampleRequest);
        return requestStruct == null
               ? Optional.empty()
               : Optional.of(requestStruct.getClass().getEnclosingClass().getName());
    }

    /**
     * Guesses the method name of an example request: the part of the request struct class name
     * between the last {@code '$'} and the trailing {@code "_args"}.
     *
     * @return the method name, or {@link Optional#empty()} if {@code asTBase(Object)} does not accept
     *         the request as a request struct
     */
    @Override
    public Optional<String> guessServiceMethodName(Object exampleRequest) {
        final TBase<?, ?> requestStruct = asTBase(exampleRequest);
        if (requestStruct == null) {
            return Optional.empty();
        }

        final String className = requestStruct.getClass().getName();
        final int methodNameStart = className.lastIndexOf('$') + 1;
        final int methodNameEnd = className.length() - REQUEST_STRUCT_SUFFIX.length();
        return Optional.of(className.substring(methodNameStart, methodNameEnd));
    }

    /**
     * Serializes an example request with the {@link ThriftProtocolFactories#TEXT} protocol.
     *
     * @param serviceName not used by this implementation
     * @param methodName not used by this implementation
     * @param exampleRequest the example request
     * @return the serialized text, or {@link Optional#empty()} if {@code exampleRequest} is not a
     *         {@link TBase}
     */
    @Override
    public Optional<String> serializeExampleRequest(String serviceName, String methodName, Object exampleRequest) {
        if (exampleRequest instanceof TBase) {
            final TBase<?, ?> requestStruct = (TBase<?, ?>) exampleRequest;
            final TSerializer textSerializer = new TSerializer(ThriftProtocolFactories.TEXT);
            final String serialized;
            try {
                serialized = textSerializer.toString(requestStruct, StandardCharsets.UTF_8.name());
            } catch (TException e) {
                throw new Error("should never reach here", e);
            }
            return Optional.of(serialized);
        }

        return Optional.empty();
    }

    /**
     * Returns the given example request as a {@link TBase} if it looks like a generated request
     * struct, i.e. it is a {@link TBase} whose class name ends with {@code "_args"} and whose class is
     * nested directly inside a top-level class.
     *
     * @param exampleRequest the example request; may be any object, including {@code null}
     * @return the request struct, or {@code null} if {@code exampleRequest} is not one
     */
    @Nullable
    private static TBase<?, ?> asTBase(Object exampleRequest) {
        // Reject anything that is not a TBase (including null) instead of failing on the cast, in line
        // with the instanceof guard in serializeExampleRequest().
        if (!(exampleRequest instanceof TBase)) {
            return null;
        }

        final TBase<?, ?> exampleTBase = (TBase<?, ?>) exampleRequest;
        final Class<?> type = exampleTBase.getClass();
        if (!type.getName().endsWith(REQUEST_STRUCT_SUFFIX)) {
            return null;
        }

        // A request struct is nested in its service class ...
        final Class<?> serviceType = type.getEnclosingClass();
        if (serviceType == null) {
            return null;
        }

        // ... and the service class itself must be a top-level class.
        if (serviceType.getEnclosingClass() != null) {
            return null;
        }

        return exampleTBase;
    }
}