com.spotify.netty4.handler.codec.zmtp.ZMTPFramingEncoder.java Source code

Java tutorial

Introduction

Here is the source code for com.spotify.netty4.handler.codec.zmtp.ZMTPFramingEncoder.java

Source

/*
 * Copyright (c) 2012-2013 Spotify AB
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package com.spotify.netty4.handler.codec.zmtp;

import java.util.ArrayList;
import java.util.List;

import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.util.ReferenceCountUtil;

/**
 * Netty ZMTP encoder.
 *
 * <p>Outbound handler that buffers every written message together with its promise and, when
 * flushed, encodes the whole batch into a single {@link ByteBuf}. That buffer is written
 * downstream with one aggregate promise which forwards its outcome to the promises of the
 * individual messages.
 *
 * <p>The pending-message state is held in plain, unsynchronized lists; this class presumably
 * relies on Netty invoking the handler from the channel's event loop only (TODO confirm).
 */
class ZMTPFramingEncoder extends ChannelOutboundHandlerAdapter {

    private final ZMTPEncoder encoder;

    /** Messages written since the last flush, in write order. */
    private final List<Object> messages = new ArrayList<Object>();
    /** Promise of each buffered message; index-aligned with {@link #messages}. */
    private final List<ChannelPromise> promises = new ArrayList<ChannelPromise>();
    private final ZMTPWriter writer;
    private final ZMTPEstimator estimator;

    /**
     * Creates an encoder whose writer and estimator are created for the version negotiated by
     * the given session.
     *
     * @param session the session supplying the negotiated version
     * @param encoder the encoder used to estimate and encode messages
     * @throws NullPointerException if {@code session} or {@code encoder} is {@code null}
     */
    ZMTPFramingEncoder(final ZMTPSession session, final ZMTPEncoder encoder) {
        if (session == null) {
            throw new NullPointerException("session");
        }
        if (encoder == null) {
            throw new NullPointerException("encoder");
        }
        this.encoder = encoder;
        this.writer = ZMTPWriter.create(session.negotiatedVersion());
        this.estimator = ZMTPEstimator.create(session.negotiatedVersion());
    }

    /**
     * Creates an encoder whose writer and estimator use the given wire format.
     *
     * @param wireFormat the wire format handed to the writer and the estimator
     * @param encoder    the encoder used to estimate and encode messages
     * @throws NullPointerException if {@code wireFormat} or {@code encoder} is {@code null}
     */
    public ZMTPFramingEncoder(final ZMTPWireFormat wireFormat, final ZMTPEncoder encoder) {
        if (wireFormat == null) {
            throw new NullPointerException("wireFormat");
        }
        if (encoder == null) {
            throw new NullPointerException("encoder");
        }
        this.encoder = encoder;
        this.writer = new ZMTPWriter(wireFormat);
        this.estimator = new ZMTPEstimator(wireFormat);
    }

    /**
     * Discards any messages that were written but never flushed, then closes the encoder.
     *
     * <p>Discarded messages are released and their promises failed so that neither buffers nor
     * callers waiting on those promises are left dangling.
     */
    @Override
    public void handlerRemoved(final ChannelHandlerContext ctx) {
        try {
            if (!messages.isEmpty() || !promises.isEmpty()) {
                discardPending(new IllegalStateException(
                        "handler removed before pending messages were flushed"));
            }
        } finally {
            // The encoder must be closed even if discarding a message fails.
            encoder.close();
        }
    }

    /**
     * Buffers the message and its promise; nothing is encoded or written until
     * {@link #flush(ChannelHandlerContext)} is called.
     */
    @Override
    public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
        messages.add(msg);
        promises.add(promise);
    }

    /**
     * Encodes all buffered messages into one buffer, writes it downstream with an aggregate
     * promise and flushes.
     *
     * <p>If nothing is buffered the flush is simply propagated. If estimating or encoding fails,
     * the output buffer and all messages of the batch are released, the promises of the batch
     * are failed with the cause, and the exception is rethrown.
     *
     * @throws Exception whatever the encoder throws while estimating or encoding
     */
    @Override
    public void flush(final ChannelHandlerContext ctx) throws Exception {
        if (messages.isEmpty()) {
            // Nothing to encode: pass the flush on rather than writing an empty buffer.
            ctx.flush();
            return;
        }

        // Take ownership of the batch up front, so that a failure below cannot leave already
        // released messages in the lists to be encoded (and released) again by the next flush.
        final Object[] batch = messages.toArray();
        final ChannelPromise aggregate = new AggregatePromise(ctx.channel(), promises);
        messages.clear();
        promises.clear();

        ByteBuf output = null;
        boolean encoded = false;
        try {
            // First pass: compute the total size so the buffer is allocated exactly once.
            estimator.reset();
            for (final Object message : batch) {
                encoder.estimate(message, estimator);
            }
            output = ctx.alloc().buffer(estimator.size());
            writer.reset(output);
            // Second pass: serialize every message into the shared buffer.
            for (final Object message : batch) {
                encoder.encode(message, writer);
            }
            encoded = true;
        } catch (final Exception e) {
            // Let every caller waiting on a message of this batch know that it failed.
            aggregate.tryFailure(e);
            throw e;
        } finally {
            for (final Object message : batch) {
                ReferenceCountUtil.release(message);
            }
            if (!encoded && output != null) {
                // The buffer never reached the pipeline, so it has to be released here.
                output.release();
            }
        }

        ctx.write(output, aggregate);
        ctx.flush();
    }

    /**
     * Releases all buffered messages and fails their promises with the given cause.
     */
    private void discardPending(final Throwable cause) {
        final Object[] batch = messages.toArray();
        final ChannelPromise[] pending = promises.toArray(new ChannelPromise[promises.size()]);
        messages.clear();
        promises.clear();
        for (final Object message : batch) {
            ReferenceCountUtil.release(message);
        }
        for (final ChannelPromise promise : pending) {
            promise.tryFailure(cause);
        }
    }

    /**
     * Promise for the combined write that forwards its completion to the promises of the
     * individual messages making up that write.
     *
     * <p>The children are always completed through the {@code try*} methods, so a child that
     * has already been completed (for example cancelled) does not prevent the remaining
     * children from being notified.
     */
    private static class AggregatePromise extends DefaultChannelPromise {

        /** Snapshot of the per-message promises taken at construction time. */
        private final ChannelPromise[] promises;

        private AggregatePromise(final Channel channel, final List<ChannelPromise> promises) {
            super(channel);
            // Copy: the caller clears and reuses its list after creating the aggregate.
            this.promises = promises.toArray(new ChannelPromise[promises.size()]);
        }

        @Override
        public ChannelPromise setSuccess(final Void result) {
            super.setSuccess(result);
            for (final ChannelPromise promise : promises) {
                promise.trySuccess();
            }
            return this;
        }

        @Override
        public boolean trySuccess() {
            final boolean result = super.trySuccess();
            for (final ChannelPromise promise : promises) {
                promise.trySuccess();
            }
            return result;
        }

        @Override
        public ChannelPromise setFailure(final Throwable cause) {
            super.setFailure(cause);
            for (final ChannelPromise promise : promises) {
                promise.tryFailure(cause);
            }
            return this;
        }

        /**
         * Counterpart of {@link #trySuccess()} for failures, so that a failure reported through
         * {@code tryFailure} also reaches the per-message promises.
         */
        @Override
        public boolean tryFailure(final Throwable cause) {
            final boolean result = super.tryFailure(cause);
            for (final ChannelPromise promise : promises) {
                promise.tryFailure(cause);
            }
            return result;
        }
    }
}