org.brutusin.rpc.websocket.WebsocketEndpoint.java Source code

Java tutorial

Introduction

Here is the source code for org.brutusin.rpc.websocket.WebsocketEndpoint.java

Source

/*
 * Copyright 2016 Ignacio del Valle Alles idelvall@brutusin.org.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.brutusin.rpc.websocket;

import java.io.IOException;
import java.lang.reflect.Type;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.websocket.CloseReason;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.Session;
import org.brutusin.commons.utils.Miscellaneous;
import org.brutusin.json.spi.JsonCodec;
import org.brutusin.json.spi.JsonSchema;
import org.brutusin.rpc.RpcRequest;
import org.brutusin.rpc.RpcResponse;
import org.brutusin.rpc.RpcSpringContext;
import org.brutusin.rpc.exception.ErrorFactory;
import org.brutusin.rpc.exception.InvalidRequestException;
import org.brutusin.rpc.exception.ServiceNotFoundException;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.security.core.context.SecurityContext;

/**
 * JSON-RPC 2.0 websocket endpoint.
 * <p>
 * Each connecting session is bound to a {@link WebsocketContext} that must have been registered
 * beforehand in {@link #getContextMap()} under the value of the session's {@code requestId}
 * request parameter. Sessions with a missing/unknown {@code requestId}, or that fail
 * {@link #allowAccess(Session, WebsocketContext)}, are closed with
 * {@link CloseReason.CloseCodes#CANNOT_ACCEPT}. Text messages received on an accepted session are
 * parsed as {@link RpcRequest}s and dispatched to the {@link WebsocketAction} registered under the
 * request's method name.
 *
 * @author Ignacio del Valle Alles idelvall@brutusin.org
 */
public class WebsocketEndpoint extends Endpoint {

    private static final Logger LOGGER = Logger.getLogger(WebsocketEndpoint.class.getName());

    /**
     * Request parameter whose value is the key of this session's entry in {@link #contextMap}.
     */
    private static final String REQUEST_ID_PARAM = "requestId";

    private final Map<String, WebsocketContext> contextMap
            = Collections.synchronizedMap(new HashMap<String, WebsocketContext>());
    private final Map<String, SessionImpl> wrapperMap
            = Collections.synchronizedMap(new HashMap<String, SessionImpl>());

    /**
     * Binds a newly opened session to its registered {@link WebsocketContext}, checks access and
     * installs the JSON-RPC message handler.
     * <p>
     * The session is closed with {@code CANNOT_ACCEPT} when it carries no {@code requestId}
     * parameter, when no context is registered for it, or when
     * {@link #allowAccess(Session, WebsocketContext)} denies it.
     *
     * @param session the session being opened
     * @param config the endpoint configuration (unused)
     * @throws RuntimeException wrapping the {@link IOException} if closing a rejected session fails
     */
    @Override
    public void onOpen(Session session, EndpointConfig config) {
        final String requestId = getRequestId(session);
        final WebsocketContext websocketContext = requestId == null ? null : contextMap.get(requestId);
        if (websocketContext == null) {
            // Without a registered context there is nothing to serve; reject instead of failing with an NPE.
            reject(session, "Invalid requestId");
            return;
        }
        if (!allowAccess(session, websocketContext)) {
            reject(session, "Authentication required");
            return;
        }
        final SessionImpl sessionImpl = new SessionImpl(session, websocketContext);
        sessionImpl.init();
        wrapperMap.put(session.getId(), sessionImpl);

        session.addMessageHandler(new MessageHandler.Whole<String>() {
            @Override
            public void onMessage(String message) {
                // Expose the session to the invoked action for the duration of this message only.
                WebsocketActionSupportImpl.setInstance(new WebsocketActionSupportImpl(sessionImpl));
                try {
                    String response = process(message, sessionImpl);
                    if (response != null) {
                        sessionImpl.sendToPeerRaw(response);
                    }
                } finally {
                    WebsocketActionSupportImpl.clear();
                }
            }
        });
    }

    /**
     * Returns the live, synchronized map of pending/active contexts keyed by {@code requestId}.
     * Callers register a context here before the corresponding websocket is opened.
     *
     * @return the (mutable, synchronized) context map
     */
    public Map<String, WebsocketContext> getContextMap() {
        return contextMap;
    }

    /**
     * Releases the resources bound to a closing session: removes its context entry, unsubscribes
     * it from every topic of its spring context and closes its {@link SessionImpl} wrapper.
     *
     * @param session the session being closed
     * @param closeReason the reason reported by the container (unused)
     */
    @Override
    public void onClose(Session session, CloseReason closeReason) {
        final String requestId = getRequestId(session);
        if (requestId != null) {
            contextMap.remove(requestId);
        }
        final SessionImpl sessionImpl = wrapperMap.remove(session.getId());
        if (sessionImpl != null) {
            try {
                // Topic.unsubscribe() works against the current action-support instance.
                WebsocketActionSupportImpl.setInstance(new WebsocketActionSupportImpl(sessionImpl));
                for (Topic topic : sessionImpl.getCtx().getSpringContext().getTopics().values()) {
                    try {
                        topic.unsubscribe();
                    } catch (InvalidSubscriptionException ignored) {
                        // This session was not subscribed to this topic; nothing to undo.
                    }
                }
            } finally {
                WebsocketActionSupportImpl.clear();
                sessionImpl.close();
            }
        }
    }

    /**
     * Logs errors raised on a session. {@link IOException}s were never reported by this endpoint
     * (presumably transport-level disconnects), so they are only logged at {@code FINE}.
     *
     * @param session the session on which the error occurred
     * @param thr the error
     */
    @Override
    public void onError(Session session, Throwable thr) {
        final String sessionId = session == null ? null : session.getId();
        if (thr instanceof IOException) {
            LOGGER.log(Level.FINE, "Websocket I/O error on session " + sessionId, thr);
        } else {
            LOGGER.log(Level.SEVERE, "Websocket error on session " + sessionId, thr);
        }
    }

    /**
     * Decides whether a session may be opened.
     * <p>
     * When the parent spring context defines a {@code springSecurityFilterChain} bean, access
     * requires an authenticated {@link SecurityContext} captured in the websocket context;
     * otherwise access is unrestricted. A {@code null} context, or a missing security context while
     * security is active, denies access.
     *
     * @param session the session being opened
     * @param websocketContext the context registered for the session; may be {@code null}
     * @return {@code true} if the session may proceed
     */
    protected boolean allowAccess(Session session, WebsocketContext websocketContext) {
        if (websocketContext == null) {
            return false;
        }
        final RpcSpringContext rpcCtx = websocketContext.getSpringContext();
        if (rpcCtx.getParent() != null) {
            try {
                if (rpcCtx.getParent().getBean("springSecurityFilterChain") != null) { // Security active
                    final SecurityContext sc = (SecurityContext) websocketContext.getSecurityContext();
                    return sc != null
                            && sc.getAuthentication() != null
                            && sc.getAuthentication().isAuthenticated();
                }
            } catch (NoSuchBeanDefinitionException ex) {
                // No security filter chain configured: access is unrestricted.
                return true;
            }
        }
        return true;
    }

    /**
     * Parses and executes one JSON-RPC message.
     *
     * @param message the raw JSON text received from the peer
     * @param sessionImpl the session wrapper whose spring context provides the services
     * @return the serialized {@link RpcResponse}, or {@code null} when the request was a valid
     * notification (parsed successfully but carries no {@code id})
     */
    private String process(String message, SessionImpl sessionImpl) {
        RpcRequest req = null;
        Object result = null;
        Throwable throwable = null;
        try {
            req = JsonCodec.getInstance().parse(message, RpcRequest.class);
            result = execute(req, sessionImpl.getCtx().getSpringContext());
        } catch (Throwable th) {
            // Any failure (parse, validation, service) is reported to the peer as an error response.
            throwable = th;
        }
        if (req != null && req.getId() == null) {
            // JSON-RPC 2.0 notification: no response is sent.
            return null;
        }
        RpcResponse resp;
        if (result instanceof RpcResponse) {
            // The action built the full response itself.
            resp = (RpcResponse) result;
        } else {
            resp = new RpcResponse();
            if (req != null) {
                resp.setId(req.getId());
            }
            // NOTE(review): presumably ErrorFactory.getError(null) yields no error on success — confirm.
            resp.setError(ErrorFactory.getError(throwable));
            resp.setResult(result);
        }
        return JsonCodec.getInstance().transform(resp);
    }

    /**
     * Executes a parsed request against the websocket services of the given spring context.
     *
     * @param request the parsed request
     * @param rpcCtx the spring context holding the registered {@link WebsocketAction}s
     * @return the value returned by the service
     * @throws InvalidRequestException if the request is not {@code jsonrpc 2.0}
     * @throws ServiceNotFoundException if no service is registered under the request's method
     * @throws Exception any error raised by schema validation, input loading or the service itself
     */
    private Object execute(RpcRequest request, RpcSpringContext rpcCtx) throws Exception {
        if (!"2.0".equals(request.getJsonrpc())) {
            throw new InvalidRequestException("Only jsonrpc 2.0 supported");
        }
        String serviceId = request.getMethod();
        Map<String, WebsocketAction> services = rpcCtx.getWebSocketServices();
        if (serviceId == null || !services.containsKey(serviceId)) {
            throw new ServiceNotFoundException();
        }
        WebsocketAction service = services.get(serviceId);
        Object input;
        if (request.getParams() == null) {
            input = null;
        } else {
            Type inputType = service.getInputType();
            JsonSchema inputSchema = JsonCodec.getInstance().getSchema(inputType);
            inputSchema.validate(request.getParams());
            if (inputType.equals(Object.class)) {
                // Untyped actions receive the raw params node.
                input = request.getParams();
            } else {
                input = JsonCodec.getInstance().load(request.getParams(), Miscellaneous.getClass(inputType));
            }
        }
        return service.execute(input);
    }

    /**
     * Closes a session that cannot be served.
     *
     * @param session the session to close
     * @param reason the close reason phrase sent to the peer
     * @throws RuntimeException wrapping the {@link IOException} raised while closing
     */
    private static void reject(Session session, String reason) {
        try {
            session.close(new CloseReason(CloseReason.CloseCodes.CANNOT_ACCEPT, reason));
        } catch (IOException ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Returns the first value of the session's {@code requestId} request parameter.
     *
     * @param session the session
     * @return the request id, or {@code null} if the parameter is absent or empty
     */
    private static String getRequestId(Session session) {
        final Map<String, List<String>> params = session.getRequestParameterMap();
        if (params == null) {
            return null;
        }
        final List<String> values = params.get(REQUEST_ID_PARAM);
        if (values == null || values.isEmpty()) {
            return null;
        }
        return values.get(0);
    }
}