eastwind.webpush.WebPushHandler.java Source code

Java tutorial

Introduction

Here is the source code for eastwind.webpush.WebPushHandler.java

Source

/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package eastwind.webpush;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.util.CharsetUtil;
import io.netty.util.HashedWheelTimer;
import io.netty.util.Timeout;
import io.netty.util.TimerTask;

import java.io.IOException;
import java.lang.ref.WeakReference;
import java.nio.charset.Charset;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.TimeUnit;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Splitter;
import com.google.common.collect.Lists;

/**
 * Handles HTTP requests, WebSocket handshakes and WebSocket frames for the
 * web-push endpoint.
 * <p>
 * Only {@code GET} is accepted. The request path selects the operation:
 * <ul>
 * <li>{@code /} or any single-segment path (optionally with a query string):
 * activation. {@link Action#active} maps the remote address and the raw query
 * string to a uid, a new session is created for that uid, and the channel is
 * either upgraded to a WebSocket or answered with
 * {@code {"uid":..., "uuid":...}}.</li>
 * <li>{@code /uid/uuid}: binds this channel to the existing session and asks
 * it to send pending messages.</li>
 * <li>{@code /uid/uuid/oper}: runs {@code registers}, {@code register} or
 * {@code cancel} on the session.</li>
 * </ul>
 * Over a WebSocket, text frames carry JSON {@link Message}s using the same
 * {@code registers}/{@code register}/{@code cancel} types.
 * <p>
 * {@link #handshaker} is per-connection state, so an instance presumably must
 * not be shared between channels.
 */
class WebPushHandler extends SimpleChannelInboundHandler<Object> {

    private static final Logger logger = LoggerFactory.getLogger(WebPushHandler.class);

    private final SessionManager sessionManager;

    private final Action action;
    private final HashedWheelTimer timer;
    private final ObjectMapper objectMapper;
    /** Handshaker for this channel's WebSocket upgrade; null until an upgrade is attempted. */
    private WebSocketServerHandshaker handshaker;
    /** Heartbeat interval in milliseconds. */
    private final int tickTime;
    /**
     * Idle limit in milliseconds (2.5 x {@link #tickTime}): a channel with no
     * recorded read for longer than this is closed. Also used as the
     * session-cleanup period.
     */
    private final int lost;

    /**
     * @param action         resolves the uid for an activation request
     * @param timer          timer driving heartbeats and session cleanup
     * @param sessionManager session registry, looked up by uid and uuid
     * @param objectMapper   JSON mapper for WebSocket messages and HTTP bodies
     * @param tickTime       heartbeat interval in milliseconds; the idle limit is 2.5x this
     */
    public WebPushHandler(Action action, HashedWheelTimer timer, SessionManager sessionManager,
            ObjectMapper objectMapper, int tickTime) {
        this.action = action;
        this.timer = timer;
        this.sessionManager = sessionManager;
        this.objectMapper = objectMapper;
        this.tickTime = tickTime;
        this.lost = tickTime * 5 / 2;
    }

    /** Starts the heartbeat task for the newly active channel, then propagates the event. */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        timer.newTimeout(new ChannelPinger(ctx.channel()), tickTime, TimeUnit.MILLISECONDS);
        super.channelActive(ctx);
    }

    /**
     * Dispatches full HTTP requests and WebSocket frames. Other message types
     * are ignored. SimpleChannelInboundHandler releases {@code msg} after this
     * returns, which is why frames passed on elsewhere are {@code retain()}ed.
     */
    @Override
    public void channelRead0(ChannelHandlerContext ctx, Object msg) throws JsonProcessingException {
        if (msg instanceof FullHttpRequest) {
            handleHttpRequest(ctx, (FullHttpRequest) msg);
        } else if (msg instanceof WebSocketFrame) {
            handleWebSocketFrame(ctx, (WebSocketFrame) msg);
        }
    }

    /**
     * Validates the request and routes it by path. Paths with at least two
     * segments address an existing session; anything shorter is an activation.
     */
    private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) throws JsonProcessingException {
        // Handle a bad request.
        if (!req.decoderResult().isSuccess()) {
            sendHttpResponse(ctx, req, errorResponse(HttpResponseStatus.BAD_REQUEST));
            return;
        }

        // Allow only GET methods.
        if (req.method() != HttpMethod.GET) {
            sendHttpResponse(ctx, req, errorResponse(HttpResponseStatus.FORBIDDEN));
            return;
        }

        String uri = req.uri();
        int q = uri.indexOf('?');
        String path = q == -1 ? uri : uri.substring(0, q);
        List<String> segments = Lists.newArrayList(Splitter.on("/").omitEmptyStrings().trimResults().split(path));
        if (segments.size() >= 2) {
            handleSessionRequest(ctx, req, segments);
        } else {
            // Raw query string (without '?'), or "" when absent or empty.
            String params = (q != -1 && q < uri.length() - 1) ? uri.substring(q + 1) : "";
            handleActivation(ctx, req, params);
        }
    }

    /** Handles {@code /uid/uuid[/oper]}: rejects unknown sessions, otherwise binds or runs the operation. */
    private void handleSessionRequest(ChannelHandlerContext ctx, FullHttpRequest req, List<String> segments)
            throws JsonProcessingException {
        String uid = segments.get(0);
        String uuid = segments.get(1);
        Session s = sessionManager.get(uid, uuid);
        if (s == null) {
            logger.info("expired:{}-{}", uid, uuid);
            sendHttpResponse(ctx, req, errorResponse(HttpResponseStatus.FORBIDDEN));
            // Must stop here: the code below dereferences the session.
            return;
        }
        if (segments.size() == 2) {
            // No response is written here; presumably trySendMessages() answers
            // on the bound channel (TODO confirm in Session).
            s.setChannel(ctx.channel());
            s.trySendMessages();
        } else {
            handleHttpOper(ctx, req, s, segments.get(2));
        }
    }

    /**
     * Activation: asks {@link #action} for a uid and creates a session for it.
     * Then either upgrades the channel to a WebSocket or replies with the
     * uid/uuid pair as JSON. Replies 500 if the action throws and 403 if it
     * returns no uid.
     */
    private void handleActivation(ChannelHandlerContext ctx, FullHttpRequest req, String params)
            throws JsonProcessingException {
        Channel channel = ctx.channel();
        String uid;
        try {
            uid = action.active(channel.remoteAddress(), params);
        } catch (Throwable th) {
            logger.warn("active:", th);
            sendHttpResponse(ctx, req,
                    new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR,
                            Unpooled.copiedBuffer(th.getClass().getName(), CharsetUtil.UTF_8)));
            return;
        }
        if (uid == null) {
            // Without a uid no session is created, so any "uid"/"uuid" handed out
            // here would be rejected on every later request; refuse up front.
            logger.info("rejected:{}", channel.remoteAddress());
            sendHttpResponse(ctx, req, errorResponse(HttpResponseStatus.FORBIDDEN));
            return;
        }

        String uuid = sessionManager.create(uid).getUuid();
        logger.info("active:{}-{}", uid, uuid);
        SessionGroup sg = sessionManager.get(uid);
        // NOTE(review): every activation of the same uid schedules another
        // self-rescheduling cleaner for the same group; they only stop once the
        // group is removed. Confirm this is intended.
        timer.newTimeout(new SessionCleaner(sg), lost, TimeUnit.MILLISECONDS);

        if (isWebSocketUpgrade(req)) {
            WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                    getWebSocketLocation(req, ""), null, true);
            handshaker = wsFactory.newHandshaker(req);
            if (handshaker == null) {
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(channel);
                return;
            }
            handshaker.handshake(channel, req);
            UserLite.set(channel, new UserLite(uid, uuid));
        } else {
            // writeValueAsString(String) yields a quoted, escaped JSON string literal.
            String content = String.format("{\"uid\":%s, \"uuid\":%s}",
                    objectMapper.writeValueAsString(uid), objectMapper.writeValueAsString(uuid));
            sendHttpResponse(ctx, req, jsonOk(content));
        }
    }

    /**
     * Returns true if the request asks for a WebSocket upgrade. Per RFC 6455
     * the tokens are matched case-insensitively and inside comma-separated
     * lists. Browsers such as Firefox send {@code Connection: keep-alive, Upgrade}.
     */
    private static boolean isWebSocketUpgrade(FullHttpRequest req) {
        return headerHasToken(req, HttpHeaderNames.CONNECTION, "upgrade")
                && headerHasToken(req, HttpHeaderNames.UPGRADE, "websocket");
    }

    /** True if any value of header {@code name} contains {@code token} as a comma-separated element. */
    private static boolean headerHasToken(FullHttpRequest req, CharSequence name, String token) {
        for (String value : req.headers().getAll(name)) {
            for (String part : Splitter.on(',').trimResults().split(value)) {
                if (part.equalsIgnoreCase(token)) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Runs an HTTP session operation.
     * <ul>
     * <li>{@code registers}: registers all {@code type} query parameters.</li>
     * <li>{@code register}: registers the first {@code type} query parameter.</li>
     * <li>{@code cancel}: cancels the session and removes it from its group.</li>
     * </ul>
     * Every operation gets a response; an unknown operation gets 404.
     */
    private void handleHttpOper(ChannelHandlerContext ctx, FullHttpRequest req, Session s, String oper)
            throws JsonProcessingException {
        String uid = s.getUid();
        if ("registers".equals(oper)) {
            List<String> types = new QueryStringDecoder(req.uri()).parameters().get("type");
            s.registerTypes(types);
            if (types != null) {
                logger.debug("registers:{}-{}", uid, objectMapper.writeValueAsString(types));
            }
            sendHttpResponse(ctx, req, jsonOk("{}"));
        } else if ("register".equals(oper)) {
            List<String> types = new QueryStringDecoder(req.uri()).parameters().get("type");
            if (types != null && !types.isEmpty()) {
                s.registerType(types.get(0));
                logger.debug("register:{}-{}", uid, types.get(0));
            }
            sendHttpResponse(ctx, req, jsonOk("{}"));
        } else if ("cancel".equals(oper)) {
            s.setCanceled();
            removeSession(uid, s);
            logger.debug("cancel:{}-{}", uid, s.getUuid());
            // Answer the request; otherwise the client waits forever.
            sendHttpResponse(ctx, req, jsonOk("{}"));
        } else {
            sendHttpResponse(ctx, req, errorResponse(HttpResponseStatus.NOT_FOUND));
        }
    }

    /**
     * Handles one WebSocket frame.
     * <ul>
     * <li>Close frames complete the closing handshake.</li>
     * <li>Every other frame requires a live session, otherwise
     * {@code Message.FORBIDDEN} is sent.</li>
     * <li>For a live session the read time is recorded, pings are answered
     * with pongs, and text frames are parsed as JSON messages.</li>
     * </ul>
     */
    private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
        Channel channel = ctx.channel();
        if (frame instanceof CloseWebSocketFrame) {
            handshaker.close(channel, (CloseWebSocketFrame) frame.retain());
            return;
        }

        // The user may be absent (e.g. cleared below after an earlier rejection).
        UserLite u = UserLite.get(channel);
        Session s = u == null ? null : sessionManager.get(u.uid, u.uuid);
        if (s == null) {
            UserLite.set(channel, null);
            ctx.writeAndFlush(Message.FORBIDDEN);
            return;
        }
        Stat.setLastRead(channel);

        if (frame instanceof PingWebSocketFrame) {
            // Flush now: nothing in this handler flushes a plain write().
            channel.writeAndFlush(new PongWebSocketFrame(frame.content().retain()));
            return;
        }

        if (frame instanceof TextWebSocketFrame) {
            try {
                handleWsOper(channel, u, s, new ByteBufInputStream(frame.content()));
            } catch (IOException e) {
                logger.warn("message:", e);
            }
        }
    }

    /**
     * Parses a JSON {@link Message} from {@code is} and applies it to {@code s}.
     * <ul>
     * <li>{@code registers} expects a collection of types and also binds the
     * channel to the session.</li>
     * <li>{@code register} expects a single type string.</li>
     * <li>{@code cancel} cancels and removes the session.</li>
     * </ul>
     * Unknown or missing types are ignored. Payloads of the wrong shape are
     * logged and ignored.
     */
    private void handleWsOper(Channel channel, UserLite u, Session s, ByteBufInputStream is)
            throws IOException, JsonParseException, JsonMappingException, JsonProcessingException {
        Message message = objectMapper.readValue(is, Message.class);
        Object data = message.getData();
        String type = message.getType();
        if ("registers".equals(type)) {
            if (data != null && !(data instanceof Collection)) {
                logger.warn("registers: unexpected data {}-{}", u.uid, data);
                return;
            }
            // Presumably a JSON array deserialized as a List; element types are not checked.
            @SuppressWarnings("unchecked")
            Collection<String> types = (Collection<String>) data;
            s.registerTypes(types);
            s.setChannel(channel);
            s.trySendMessages();
            logger.debug("registers:{}-{}", u.uid, objectMapper.writeValueAsString(data));
        } else if ("register".equals(type)) {
            if (data instanceof String) {
                s.registerType((String) data);
                logger.debug("register:{}-{}", u.uid, data);
            } else if (data != null) {
                logger.warn("register: unexpected data {}-{}", u.uid, data);
            }
        } else if ("cancel".equals(type)) {
            logger.debug("cancel:{}-{}", u.uid, u.uuid);
            s.setCanceled();
            removeSession(u.getUid(), s);
        }
    }

    /** Removes {@code s} from its uid's group; the group may already be gone if SessionCleaner removed it. */
    private void removeSession(String uid, Session s) {
        SessionGroup sg = sessionManager.get(uid);
        if (sg != null) {
            sg.remove(s);
        }
    }

    /** Builds a 200 OK response whose body is {@code json} encoded as UTF-8. */
    private static FullHttpResponse jsonOk(String json) {
        return new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK,
                Unpooled.copiedBuffer(json, CharsetUtil.UTF_8));
    }

    /** Builds an empty response with {@code status}; sendHttpResponse adds the status text as body. */
    private static FullHttpResponse errorResponse(HttpResponseStatus status) {
        return new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status);
    }

    /**
     * Writes {@code res} with JSON content-type, a permissive CORS header and
     * Content-Length. For non-200 statuses the status text is appended to the
     * body and the connection is closed after the write. Otherwise it is
     * closed only when the request is not keep-alive.
     */
    private static void sendHttpResponse(ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) {
        // Generate an error page if response status code is not OK (200).
        if (res.status().code() != 200) {
            ByteBuf buf = Unpooled.copiedBuffer(res.status().toString(), CharsetUtil.UTF_8);
            res.content().writeBytes(buf);
            buf.release();
        }

        // Send the response and close the connection if necessary.
        res.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/json; charset=UTF-8");
        res.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, "*");
        HttpUtil.setContentLength(res, res.content().readableBytes());
        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!HttpUtil.isKeepAlive(req) || res.status().code() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * Logs unexpected exceptions. A plain {@link IOException} (exact class
     * only), usually the peer dropping the connection, is ignored. The channel
     * is not closed here.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        if (cause.getClass().equals(IOException.class)) {
            return;
        }
        logger.warn("exception:", cause);
    }

    /** Builds the {@code ws://} URL from the request's Host header plus {@code group}. */
    private static String getWebSocketLocation(FullHttpRequest req, String group) {
        String location = req.headers().get(HttpHeaderNames.HOST) + group;
        return "ws://" + location;
    }

    /**
     * Heartbeat task. It closes the channel once it has been idle longer than
     * {@link #lost}, and otherwise sends {@code Message.PING} whenever a full
     * tick passed without a read, rescheduling itself.
     * <p>
     * NOTE(review): if {@code Stat.getLastRead} returns -1 the task stops for
     * good. In this class {@code Stat.setLastRead} is only called for WebSocket
     * frames, so unless Stat records reads elsewhere, plain HTTP channels (and
     * WebSockets silent until the first tick) are never pinged or idle-closed.
     */
    private class ChannelPinger implements TimerTask {

        /** Weak so a scheduled ping does not keep a dead channel reachable. */
        private final WeakReference<Channel> channelRef;

        public ChannelPinger(Channel c) {
            this.channelRef = new WeakReference<Channel>(c);
        }

        @Override
        public void run(Timeout timeout) throws Exception {
            Channel c = channelRef.get();
            if (c == null || !c.isActive()) {
                return;
            }
            long lastRead = Stat.getLastRead(c);
            if (lastRead == -1) {
                return;
            }
            long now = System.currentTimeMillis();
            long diff = now - lastRead;
            if (diff > lost) {
                c.close();
                return;
            }
            if (diff > tickTime) {
                c.writeAndFlush(Message.PING);
                timer.newTimeout(this, tickTime, TimeUnit.MILLISECONDS);
            } else {
                // Wake up exactly one tick after the last read.
                timer.newTimeout(this, tickTime - diff, TimeUnit.MILLISECONDS);
            }
        }
    }

    /**
     * Periodic task (every {@link #lost} ms) that calls {@code sg.clean()}.
     * Once the group is empty it marks it removed, re-checks the size and then
     * removes it from the session manager. If a session appeared in between,
     * it un-marks the group and keeps running.
     */
    private class SessionCleaner implements TimerTask {

        private final SessionGroup sg;

        public SessionCleaner(SessionGroup sg) {
            this.sg = sg;
        }

        @Override
        public void run(Timeout timeout) throws Exception {
            sg.clean();
            if (sg.size() == 0) {
                logger.info("clean:{}", sg.getUid());
                sg.setRemoved(true);
                // Re-check after flagging: presumably guards against a session
                // being added concurrently (TODO confirm in SessionGroup).
                if (sg.size() == 0) {
                    sessionManager.remove(sg);
                    return;
                } else {
                    sg.setRemoved(false);
                }
            }
            timer.newTimeout(this, lost, TimeUnit.MILLISECONDS);
        }

    }

}