org.diqube.ui.websocket.request.JsonRequestDeserializer.java Source code

Java tutorial

Introduction

Here is the source code for org.diqube.ui.websocket.request.JsonRequestDeserializer.java

Source

/**
 * diqube: Distributed Query Base.
 *
 * Copyright (C) 2015 Bastian Gloeckle
 *
 * This file is part of diqube.
 *
 * diqube is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.diqube.ui.websocket.request;

import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import javax.annotation.PostConstruct;
import javax.inject.Inject;
import javax.websocket.Session;

import org.diqube.context.AutoInstatiate;
import org.diqube.thrift.base.thrift.AuthenticationException;
import org.diqube.thrift.base.thrift.Ticket;
import org.diqube.ticket.TicketUtil;
import org.diqube.ticket.TicketValidityService;
import org.diqube.ui.ticket.TicketsAcceptableProvider;
import org.diqube.ui.websocket.request.commands.CommandInformation;
import org.diqube.ui.websocket.request.commands.JsonCommand;
import org.diqube.ui.websocket.result.JsonResult;
import org.diqube.util.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext;
import org.springframework.util.ReflectionUtils;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.io.BaseEncoding;
import com.google.common.reflect.ClassPath;
import com.google.common.reflect.ClassPath.ClassInfo;

/**
 * Deserializes arbitrary {@link JsonRequest} objects, including their {@link JsonCommand}s.
 * 
 * <p>
 * In addition to deserializing the {@link JsonRequest} from JSON and the {@link JsonCommand} from JSON, the
 * {@link JsonCommand} may contain additional fields that have both annotations: {@link Inject} and {@link JsonIgnore}.
 * These fields are then wired to instances of matching types from the bean context.
 * 
 * <p>
 * If the command has a public, parameterless, void method annotated with {@link PostConstruct}, it will be called
 * exactly once after all fields (including those of superclasses) have been wired.
 * 
 * <p>
 * This class will validate the {@link Ticket} if one is provided and provide a corresponding {@link JsonCommand}.
 *
 * @author Bastian Gloeckle
 */
@AutoInstatiate
public class JsonRequestDeserializer {
    private static final Logger logger = LoggerFactory.getLogger(JsonRequestDeserializer.class);
    private static final String JSON_REQUEST_ID = "requestId";
    private static final String JSON_COMMAND_NAME = "command";
    private static final String JSON_COMMAND_DATA = "commandData";
    private static final String JSON_TICKET = "ticket";

    /** Maps {@link CommandInformation#name()} to the implementing command class; filled by {@link #initialize()}. */
    private Map<String, Class<? extends JsonCommand>> commandClasses;

    @Inject
    private ApplicationContext beanContext;

    @Inject
    private JsonRequestRegistry jsonRequestRegistry;

    @Inject
    private TicketValidityService ticketValidityService;

    @Inject
    private TicketsAcceptableProvider ticketsAcceptableProvider;

    private final JsonFactory jsonFactory = new JsonFactory();
    private final ObjectMapper mapper = new ObjectMapper(jsonFactory);

    /**
     * Deserialize an arbitrary {@link JsonRequest}.
     * 
     * <p>
     * This method validates the {@link Ticket} if one is provided in the JSON. If the ticket is invalid (or tickets are
     * currently not acceptable at all), the returned {@link JsonRequest} will have a command that simply fails with an
     * {@link AuthenticationException} when executed.
     * 
     * @param clientJson
     *          The JSON provided by the client which contains information about the request to be built.
     * @param websocketSession
     *          the {@link Session} to which the request belongs.
     * @return the new object
     * @throws JsonPayloadDeserializerException
     *           if anything went wrong. Specific problems (unknown command, missing field, command class not
     *           instantiable) are reported with their own message; everything else is reported as "Invalid JSON" with
     *           the original exception as cause.
     */
    public JsonRequest deserialize(String clientJson, Session websocketSession)
            throws JsonPayloadDeserializerException {
        try {
            JsonNode requestTreeRoot = mapper.readTree(clientJson);
            if (requestTreeRoot == null) {
                // Some Jackson versions return null for empty content.
                throw new JsonPayloadDeserializerException("Invalid JSON: empty payload");
            }
            String requestId = requiredField(requestTreeRoot, JSON_REQUEST_ID).textValue();
            String commandName = requiredField(requestTreeRoot, JSON_COMMAND_NAME).textValue();

            Class<? extends JsonCommand> cmdClass = commandClasses.get(commandName);
            if (cmdClass == null) {
                throw new JsonPayloadDeserializerException("Unknown command: " + commandName);
            }

            JsonCommand cmd = createCommand(cmdClass, requestTreeRoot.get(JSON_COMMAND_DATA));

            Ticket ticket = null;
            JsonNode ticketNode = requestTreeRoot.get(JSON_TICKET);
            if (ticketNode != null) {
                if (!ticketsAcceptableProvider.areTicketsAcceptable()) {
                    // currently we do not accept any ticket, because we were not able to reach diqube-servers recently.
                    cmd = failingCommand("UI server does not accept any tickets currently because it was "
                            + "not able to reach any diqube server. Retry shortly.");
                } else {
                    // A non-textual ticket (textValue() == null) is treated like an undecodable one.
                    String ticketBase64 = ticketNode.textValue();
                    Pair<Ticket, byte[]> deserializedTicket = null;
                    if (ticketBase64 != null) {
                        try {
                            byte[] ticketSerialized = BaseEncoding.base64().decode(ticketBase64);
                            deserializedTicket = TicketUtil.deserialize(ByteBuffer.wrap(ticketSerialized));
                        } catch (IllegalArgumentException e) {
                            // deserializedTicket stays null, the request is rejected below.
                            logger.debug("Could not decode ticket provided by client.", e);
                        }
                    }
                    // ensure that ticket is valid. If not, reject request directly.
                    if (deserializedTicket == null || !ticketValidityService.isTicketValid(deserializedTicket)) {
                        logger.warn("Invalid ticket provided by client.");
                        cmd = failingCommand("Ticket invalid.");
                    } else {
                        ticket = deserializedTicket.getLeft();
                    }
                }
            }

            wireInjectFieldsAndCallPostConstruct(cmd);
            JsonRequest request = new JsonRequest(websocketSession, ticket, requestId, cmd, jsonRequestRegistry);
            wireInjectFieldsAndCallPostConstruct(request);

            return request;
        } catch (JsonPayloadDeserializerException e) {
            // Already carries a specific message - do not bury it under a generic "Invalid JSON".
            throw e;
        } catch (Exception e) {
            throw new JsonPayloadDeserializerException("Invalid JSON", e);
        }
    }

    /**
     * Fetches a direct child of the request root, failing with a descriptive exception if it is absent.
     *
     * @param root
     *          the parsed request.
     * @param fieldName
     *          name of the field that has to be present.
     * @return the node stored under the given field name, never <code>null</code>.
     * @throws JsonPayloadDeserializerException
     *           if the root has no such field.
     */
    private static JsonNode requiredField(JsonNode root, String fieldName) throws JsonPayloadDeserializerException {
        JsonNode node = root.get(fieldName);
        if (node == null) {
            throw new JsonPayloadDeserializerException("Invalid JSON: missing field '" + fieldName + "'");
        }
        return node;
    }

    /**
     * Creates the command instance: bound from the command data if there is any, otherwise created using the no-arg
     * constructor of the command class.
     *
     * @param cmdClass
     *          class of the command to create.
     * @param commandData
     *          the JSON node holding the command data, may be <code>null</code> or a JSON null.
     * @return the new command.
     * @throws IOException
     *           if the command data cannot be bound to the command class.
     * @throws JsonPayloadDeserializerException
     *           if the command class cannot be instantiated.
     */
    private JsonCommand createCommand(Class<? extends JsonCommand> cmdClass, JsonNode commandData)
            throws IOException, JsonPayloadDeserializerException {
        if (commandData != null && !commandData.isNull()) {
            return mapper.treeToValue(commandData, cmdClass);
        }
        try {
            // Constructor#newInstance wraps exceptions thrown by the constructor instead of propagating them unchecked.
            return cmdClass.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            logger.error("Could not instantiate command class", e);
            throw new JsonPayloadDeserializerException("Could not instantiate command class", e);
        }
    }

    /**
     * Creates a command that does nothing but throw an {@link AuthenticationException} when executed.
     *
     * @param message
     *          message of the exception to be thrown.
     * @return the failing command.
     */
    private static JsonCommand failingCommand(final String message) {
        return new JsonCommand() {
            @Override
            public void execute(Ticket ticket, CommandResultHandler resultHandler,
                    CommandClusterInteraction clusterInteraction)
                    throws RuntimeException, AuthenticationException {
                throw new AuthenticationException(message);
            }
        };
    }

    /**
     * Wires all fields of the given object (and of its superclasses) that are annotated with both {@link Inject} and
     * {@link JsonIgnore} to beans from the bean context and then calls each {@link PostConstruct} method once.
     *
     * <p>
     * PostConstruct methods are only invoked after the whole class hierarchy has been wired, so they can rely on
     * injected fields of superclasses, too.
     *
     * @param o
     *          the object to wire.
     * @throws JsonPayloadDeserializerException
     *           if a PostConstruct method could not be invoked or threw an exception.
     */
    private void wireInjectFieldsAndCallPostConstruct(Object o) throws JsonPayloadDeserializerException {
        // Keyed by method name: candidates take no parameters, so the name identifies the signature. This makes sure
        // that a method visible through several classes of the hierarchy (Class#getMethods includes inherited
        // methods) is invoked only once; the most specific declaration wins.
        Map<String, Method> postConstructMethods = new HashMap<>();

        for (Class<?> curClass = o.getClass(); curClass != null
                && !curClass.equals(Object.class); curClass = curClass.getSuperclass()) {
            for (Field f : curClass.getDeclaredFields()) {
                if (f.isAnnotationPresent(Inject.class) && f.isAnnotationPresent(JsonIgnore.class)) {
                    wireField(o, f);
                }
            }

            for (Method m : curClass.getMethods()) {
                if (isPostConstructMethod(m)) {
                    postConstructMethods.putIfAbsent(m.getName(), m);
                }
            }
        }

        for (Method m : postConstructMethods.values()) {
            try {
                m.invoke(o);
            } catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
                throw new JsonPayloadDeserializerException("Could not invoke PostConstruct.", e);
            }
        }
    }

    /**
     * Sets the given field of the target to a bean of the field's type. This is best-effort: if no bean is available
     * or the field cannot be set, this is logged and the field is left untouched.
     */
    private void wireField(Object target, Field field) {
        String targetClassName = target.getClass().getName();
        Object value;
        try {
            value = beanContext.getBean(field.getType());
        } catch (NoSuchBeanDefinitionException e) {
            logger.debug("Not wiring object to {}#{} because no corresponding bean available", targetClassName,
                    field.getName());
            return;
        }

        ReflectionUtils.makeAccessible(field);
        try {
            field.set(target, value);
            logger.trace("Wired object to {}#{}", targetClassName, field.getName());
        } catch (IllegalArgumentException | IllegalAccessException e) {
            logger.debug("Could not wire object to {}#{}", targetClassName, field.getName(), e);
        }
    }

    /**
     * @return <code>true</code> if the method is a public, parameterless void method annotated with
     *         {@link PostConstruct}.
     */
    private static boolean isPostConstructMethod(Method m) {
        return Modifier.isPublic(m.getModifiers()) && m.isAnnotationPresent(PostConstruct.class)
                && m.getParameterCount() == 0 && m.getReturnType().equals(Void.TYPE);
    }

    /**
     * Scans all top level classes below the package "org.diqube.ui" for {@link JsonCommand} implementations that are
     * annotated with {@link CommandInformation} and registers them under the name given in the annotation.
     *
     * @throws IOException
     *           if the classpath cannot be scanned.
     */
    @PostConstruct
    public void initialize() throws IOException {
        Map<String, Class<? extends JsonCommand>> discovered = new HashMap<>();
        Set<ClassInfo> classInfos = ClassPath.from(getClass().getClassLoader())
                .getTopLevelClassesRecursive("org.diqube.ui");
        for (ClassInfo classInfo : classInfos) {
            Class<?> clazz = classInfo.load();
            CommandInformation annotation = clazz.getAnnotation(CommandInformation.class);
            if (annotation != null && JsonCommand.class.isAssignableFrom(clazz)) {
                // asSubclass is a checked alternative to an unchecked cast.
                Class<? extends JsonCommand> previous = discovered.put(annotation.name(),
                        clazz.asSubclass(JsonCommand.class));
                if (previous != null) {
                    logger.warn("Command name '{}' is used by both {} and {}. Using the latter.", annotation.name(),
                            previous.getName(), clazz.getName());
                }
            }
        }
        // Publish the map only after it is fully populated.
        commandClasses = discovered;
    }

    /**
     * Thrown if a {@link JsonRequest} could not be created from the JSON provided by a client.
     */
    public static class JsonPayloadDeserializerException extends Exception {
        private static final long serialVersionUID = 1L;

        /**
         * @param msg
         *          description of what went wrong.
         */
        public JsonPayloadDeserializerException(String msg) {
            super(msg);
        }

        /**
         * @param msg
         *          description of what went wrong.
         * @param cause
         *          the exception that led to this one.
         */
        public JsonPayloadDeserializerException(String msg, Throwable cause) {
            super(msg, cause);
        }
    }
}