org.apache.drill.exec.rpc.control.TestCustomTunnel.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.drill.exec.rpc.control.TestCustomTunnel.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.drill.exec.rpc.control;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.DrillBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.util.internal.ThreadLocalRandom;

import java.util.Arrays;
import java.util.Random;

import org.apache.drill.BaseTestQuery;
import org.apache.drill.exec.proto.CoordinationProtos.DrillbitEndpoint;
import org.apache.drill.exec.proto.UserBitShared.QueryId;
import org.apache.drill.exec.rpc.UserRpcException;
import org.apache.drill.exec.rpc.control.ControlTunnel.CustomFuture;
import org.apache.drill.exec.rpc.control.ControlTunnel.CustomTunnel;
import org.apache.drill.exec.rpc.control.Controller.CustomMessageHandler;
import org.apache.drill.exec.rpc.control.Controller.CustomResponse;
import org.apache.drill.exec.server.DrillbitContext;
import org.junit.Test;

/**
 * Round-trip tests for custom control-RPC tunnels. Each test registers a {@link CustomMessageHandler}
 * on the local Drillbit under its own rpc type id (1001, 1002, 1003), opens a loopback
 * {@link ControlTunnel} to that same Drillbit, sends a message and checks the reply.
 */
public class TestCustomTunnel extends BaseTestQuery {

    /** Size in bytes of the raw body sent (and echoed back) in {@link #ensureRoundTripBytes()}. */
    private static final int BODY_SIZE = 1024;

    /** Protobuf reply returned by {@link TestCustomMessageHandler}; randomized per test instance. */
    private final QueryId expectedId = QueryId.newBuilder().setPart1(ThreadLocalRandom.current().nextLong())
            .setPart2(ThreadLocalRandom.current().nextLong()).build();

    /** Jackson reply returned by {@link TestCustomMessageHandlerJackson}. */
    private final MesgB expectedB = new MesgB().set("hello", "bye", "friend");

    /** Raw body sent by the client and echoed back by the handler; contains a copy of {@code expected}. */
    private final ByteBuf buf1;

    /** Random payload written into {@code buf1}, used to verify both directions of the byte round trip. */
    private final byte[] expected;

    public TestCustomTunnel() {
        expected = new byte[BODY_SIZE];
        new Random().nextBytes(expected);
        buf1 = UnpooledByteBufAllocator.DEFAULT.buffer(BODY_SIZE);
        buf1.writeBytes(expected);
    }

    /** Sends a protobuf-only message over the loopback tunnel and checks the protobuf reply. */
    @Test
    public void ensureRoundTrip() throws Exception {
        final DrillbitContext context = getDrillbitContext();
        final TestCustomMessageHandler handler = new TestCustomMessageHandler(context.getEndpoint(), false);
        context.getController().registerCustomHandler(1001, handler, DrillbitEndpoint.PARSER);
        final ControlTunnel loopbackTunnel = context.getController().getTunnel(context.getEndpoint());
        final CustomTunnel<DrillbitEndpoint, QueryId> tunnel = loopbackTunnel.getCustomTunnel(1001,
                DrillbitEndpoint.class, QueryId.PARSER);
        final CustomFuture<QueryId> future = tunnel.send(context.getEndpoint());
        assertEquals(expectedId, future.get());
    }

    /**
     * Sends a protobuf message plus a raw byte body; the handler verifies the bytes and echoes them back,
     * and this test checks both the protobuf reply and the returned bytes.
     */
    @Test
    public void ensureRoundTripBytes() throws Exception {
        final DrillbitContext context = getDrillbitContext();
        final TestCustomMessageHandler handler = new TestCustomMessageHandler(context.getEndpoint(), true);
        context.getController().registerCustomHandler(1002, handler, DrillbitEndpoint.PARSER);
        final ControlTunnel loopbackTunnel = context.getController().getTunnel(context.getEndpoint());
        final CustomTunnel<DrillbitEndpoint, QueryId> tunnel = loopbackTunnel.getCustomTunnel(1002,
                DrillbitEndpoint.class, QueryId.PARSER);

        // Hand an extra reference to send(). NOTE(review): assumes the tunnel releases the bodies it
        // sends, so buf1 stays usable for the handler's echo — confirm against ControlTunnel.
        buf1.retain();
        final CustomFuture<QueryId> future = tunnel.send(context.getEndpoint(), buf1);
        assertEquals(expectedId, future.get());

        final ByteBuf received = future.getBuffer();
        assertTrue("Expected a byte body in the reply.", received != null);
        final byte[] actual = new byte[BODY_SIZE];
        try {
            received.getBytes(0, actual);
        } finally {
            // Release even when the copy fails (e.g. a short reply), so the reply buffer never leaks.
            received.release();
        }
        assertTrue("Round-tripped bytes differ from the bytes that were sent.", Arrays.equals(expected, actual));
    }

    /**
     * Handler for {@link DrillbitEndpoint} requests. Rejects any request whose protobuf body differs
     * from {@code expectedValue}; when {@code returnBytes} is set it also requires the raw body to equal
     * {@code expected} and echoes {@code buf1} back. Always replies with {@code expectedId}.
     */
    private class TestCustomMessageHandler implements CustomMessageHandler<DrillbitEndpoint, QueryId> {
        private final DrillbitEndpoint expectedValue;
        private final boolean returnBytes;

        public TestCustomMessageHandler(DrillbitEndpoint expectedValue, boolean returnBytes) {
            this.expectedValue = expectedValue;
            this.returnBytes = returnBytes;
        }

        @Override
        public CustomResponse<QueryId> onMessage(DrillbitEndpoint pBody, DrillBuf dBody) throws UserRpcException {

            if (!expectedValue.equals(pBody)) {
                throw new UserRpcException(expectedValue, "Invalid expected downstream value.",
                        new IllegalStateException("Unexpected protobuf body: " + pBody));
            }

            if (returnBytes) {
                if (dBody == null) {
                    throw new UserRpcException(expectedValue, "Invalid expected downstream value.",
                            new IllegalStateException("Expected a byte body but none was received."));
                }
                final byte[] actual = new byte[expected.length];
                dBody.getBytes(0, actual);
                if (!Arrays.equals(expected, actual)) {
                    throw new UserRpcException(expectedValue, "Invalid expected downstream value.",
                            new IllegalStateException("Received byte body does not match the expected payload."));
                }
            }

            return new CustomResponse<QueryId>() {

                @Override
                public QueryId getMessage() {
                    return expectedId;
                }

                @Override
                public ByteBuf[] getBodies() {
                    if (!returnBytes) {
                        return null;
                    }
                    // One extra reference per reply. NOTE(review): presumably the RPC layer releases
                    // response bodies after writing them — confirm.
                    buf1.retain();
                    return new ByteBuf[] { buf1 };
                }

            };
        }
    }

    /** Sends a Jackson-serialized {@link MesgA} and checks that the Jackson-serialized reply equals {@code expectedB}. */
    @Test
    public void ensureRoundTripJackson() throws Exception {
        final DrillbitContext context = getDrillbitContext();
        final MesgA mesgA = new MesgA();
        mesgA.fieldA = "123";
        mesgA.fieldB = "okra";

        final TestCustomMessageHandlerJackson handler = new TestCustomMessageHandlerJackson(mesgA);
        context.getController().registerCustomHandler(1003, handler,
                new ControlTunnel.JacksonSerDe<MesgA>(MesgA.class),
                new ControlTunnel.JacksonSerDe<MesgB>(MesgB.class));
        final ControlTunnel loopbackTunnel = context.getController().getTunnel(context.getEndpoint());
        final CustomTunnel<MesgA, MesgB> tunnel = loopbackTunnel.getCustomTunnel(1003,
                new ControlTunnel.JacksonSerDe<MesgA>(MesgA.class),
                new ControlTunnel.JacksonSerDe<MesgB>(MesgB.class));
        final CustomFuture<MesgB> future = tunnel.send(mesgA);
        assertEquals(expectedB, future.get());
    }

    /** Jackson request payload with two string fields; value-based equality. */
    public static class MesgA {
        public String fieldA;
        public String fieldB;

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + ((fieldA == null) ? 0 : fieldA.hashCode());
            result = prime * result + ((fieldB == null) ? 0 : fieldB.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null) {
                return false;
            }
            if (getClass() != obj.getClass()) {
                return false;
            }
            MesgA other = (MesgA) obj;
            if (fieldA == null) {
                if (other.fieldA != null) {
                    return false;
                }
            } else if (!fieldA.equals(other.fieldA)) {
                return false;
            }
            if (fieldB == null) {
                if (other.fieldB != null) {
                    return false;
                }
            } else if (!fieldB.equals(other.fieldB)) {
                return false;
            }
            return true;
        }

        /** Readable form so assertion failures show the field values. */
        @Override
        public String toString() {
            return "MesgA [fieldA=" + fieldA + ", fieldB=" + fieldB + "]";
        }

    }

    /** Jackson reply payload with three string fields; value-based equality. */
    public static class MesgB {
        public String fieldA;
        public String fieldB;
        public String fieldC;

        /** Sets all three fields and returns {@code this} for chaining. */
        public MesgB set(String a, String b, String c) {
            fieldA = a;
            fieldB = b;
            fieldC = c;
            return this;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + ((fieldA == null) ? 0 : fieldA.hashCode());
            result = prime * result + ((fieldB == null) ? 0 : fieldB.hashCode());
            result = prime * result + ((fieldC == null) ? 0 : fieldC.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null) {
                return false;
            }
            if (getClass() != obj.getClass()) {
                return false;
            }
            MesgB other = (MesgB) obj;
            if (fieldA == null) {
                if (other.fieldA != null) {
                    return false;
                }
            } else if (!fieldA.equals(other.fieldA)) {
                return false;
            }
            if (fieldB == null) {
                if (other.fieldB != null) {
                    return false;
                }
            } else if (!fieldB.equals(other.fieldB)) {
                return false;
            }
            if (fieldC == null) {
                if (other.fieldC != null) {
                    return false;
                }
            } else if (!fieldC.equals(other.fieldC)) {
                return false;
            }
            return true;
        }

        /** Readable form so {@code assertEquals(expectedB, ...)} failures show the field values. */
        @Override
        public String toString() {
            return "MesgB [fieldA=" + fieldA + ", fieldB=" + fieldB + ", fieldC=" + fieldC + "]";
        }

    }

    /**
     * Jackson handler: rejects any request not equal to {@code expectedValue} and otherwise replies with
     * {@code expectedB} and no raw body.
     */
    private class TestCustomMessageHandlerJackson implements CustomMessageHandler<MesgA, MesgB> {
        private final MesgA expectedValue;

        public TestCustomMessageHandlerJackson(MesgA expectedValue) {
            this.expectedValue = expectedValue;
        }

        @Override
        public CustomResponse<MesgB> onMessage(MesgA pBody, DrillBuf dBody) throws UserRpcException {

            if (!expectedValue.equals(pBody)) {
                throw new UserRpcException(DrillbitEndpoint.getDefaultInstance(),
                        "Invalid expected downstream value.",
                        new IllegalStateException("Unexpected request: " + pBody));
            }

            return new CustomResponse<MesgB>() {

                @Override
                public MesgB getMessage() {
                    return expectedB;
                }

                @Override
                public ByteBuf[] getBodies() {
                    return null;
                }

            };
        }
    }
}