org.apache.hadoop.hdfs.qjournal.client.TestQJMWithFaults.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.hadoop.hdfs.qjournal.client.TestQJMWithFaults.java, a JUnit test of QuorumJournalManager behavior under injected faults:

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.hdfs.qjournal.client;

import static org.apache.hadoop.hdfs.qjournal.QJMTestUtil.FAKE_NSINFO;
import static org.apache.hadoop.hdfs.qjournal.QJMTestUtil.JID;
import static org.apache.hadoop.hdfs.qjournal.QJMTestUtil.writeSegment;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.io.Closeable;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.net.InetSocketAddress;
import java.net.URISyntaxException;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.SortedSet;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.logging.impl.Log4JLogger;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.hdfs.qjournal.MiniJournalCluster;
import org.apache.hadoop.hdfs.qjournal.QJMTestUtil;
import org.apache.hadoop.hdfs.qjournal.protocol.QJournalProtocol;
import org.apache.hadoop.hdfs.qjournal.server.JournalFaultInjector;
import org.apache.hadoop.hdfs.server.namenode.EditLogFileOutputStream;
import org.apache.hadoop.hdfs.server.namenode.EditLogOutputStream;
import org.apache.hadoop.hdfs.server.namenode.NameNodeLayoutVersion;
import org.apache.hadoop.hdfs.server.protocol.NamespaceInfo;
import org.apache.hadoop.hdfs.util.Holder;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.ipc.ProtobufRpcEngine;
import org.apache.hadoop.test.GenericTestUtils;
import org.apache.log4j.Level;
import org.junit.Test;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.MoreExecutors;

/**
 * Fault-injection tests for {@link QuorumJournalManager}.
 * <p>
 * Every JournalNode IPC proxy is wrapped in a Mockito mock so that individual
 * RPCs can be counted, failed on demand, or failed at random. The tests then
 * check that a subsequent writer always recovers at least every transaction
 * that an earlier writer saw acknowledged.
 */
public class TestQJMWithFaults {
    private static final Log LOG = LogFactory.getLog(TestQJMWithFaults.class);

    /** System property that, when set, fixes the seed used by {@link #testRandomized()}. */
    private static final String RAND_SEED_PROPERTY = "TestQJMWithFaults.random-seed";

    /** Number of successive writers started by {@link #testRandomized()}. */
    private static final int NUM_WRITER_ITERS = 500;
    /** Number of log segments each randomized writer attempts to write. */
    private static final int SEGMENTS_PER_WRITER = 2;

    /** Configuration shared by the clusters and QJMs created in the tests. */
    private static final Configuration conf = new Configuration();

    static {
        // Don't retry connections - it just slows down the tests.
        conf.setInt(CommonConfigurationKeysPublic.IPC_CLIENT_CONNECT_MAX_RETRIES_KEY, 0);

        // Make tests run faster by avoiding fsync()
        EditLogFileOutputStream.setShouldSkipFsyncForTesting(true);
    }

    // Set up fault injection mock. It is installed into the static
    // JournalFaultInjector.instance field so RandomFaultyChannel can stub it
    // per call (presumably consulted by the in-process JournalNodes).
    private static final JournalFaultInjector faultInjector = JournalFaultInjector.instance = Mockito
            .mock(JournalFaultInjector.class);

    /**
     * Run through the creation of a log without any faults injected,
     * and count how many RPCs are made to each node. This sets the
     * bounds for the other test cases, so they can exhaustively explore
     * the space of potential failures.
     *
     * @return the number of RPCs each logger issued while running {@link #doWorkload}
     * @throws Exception if the cluster cannot be started or the workload fails
     */
    private static long determineMaxIpcNumber() throws Exception {
        // The cluster itself gets a fresh Configuration; the QJM still uses the shared one.
        Configuration clusterConf = new Configuration();
        MiniJournalCluster cluster = new MiniJournalCluster.Builder(clusterConf).build();
        QuorumJournalManager qjm = null;
        try {
            qjm = createInjectableQJM(cluster);
            qjm.format(FAKE_NSINFO);
            doWorkload(cluster, qjm);

            SortedSet<Integer> ipcCounts = Sets.newTreeSet();
            for (AsyncLogger l : qjm.getLoggerSetForTests().getLoggersForTests()) {
                InvocationCountingChannel ch = (InvocationCountingChannel) l;
                ch.waitForAllPendingCalls();
                ipcCounts.add(ch.getRpcCount());
            }

            // All of the loggers should have sent the same number of RPCs, since there
            // were no failures.
            assertEquals(1, ipcCounts.size());

            long ret = ipcCounts.first();
            LOG.info("Max IPC count = " + ret);
            return ret;
        } finally {
            IOUtils.closeStream(qjm);
            cluster.shutdown();
        }
    }

    /**
     * Sets up two of the nodes to each drop a single RPC, at all
     * possible combinations of RPCs. This may result in the
     * active writer failing to write. After this point, a new writer
     * should be able to recover and continue writing without
     * data loss.
     */
    @Test
    public void testRecoverAfterDoubleFailures() throws Exception {
        final long maxIpcNumber = determineMaxIpcNumber();

        for (int failA = 1; failA <= maxIpcNumber; failA++) {
            for (int failB = 1; failB <= maxIpcNumber; failB++) {
                String injectionStr = "(" + failA + ", " + failB + ")";

                LOG.info("\n\n-------------------------------------------\n" + "Beginning test, failing at "
                        + injectionStr + "\n" + "-------------------------------------------\n\n");

                MiniJournalCluster cluster = new MiniJournalCluster.Builder(conf).build();
                QuorumJournalManager qjm = null;
                try {
                    qjm = createInjectableQJM(cluster);
                    qjm.format(FAKE_NSINFO);
                    List<AsyncLogger> loggers = qjm.getLoggerSetForTests().getLoggersForTests();
                    failIpcNumber(loggers.get(0), failA);
                    failIpcNumber(loggers.get(1), failB);
                    int lastAckedTxn = doWorkload(cluster, qjm);

                    // doWorkload() writes txids 1-6; a lower value means the writer gave up.
                    if (lastAckedTxn < 6) {
                        LOG.info("Failed after injecting failures at " + injectionStr
                                + ". This is expected since we injected a failure in the " + "majority.");
                    }
                    qjm.close();
                    qjm = null;

                    // Now should be able to recover
                    qjm = createInjectableQJM(cluster);
                    long lastRecoveredTxn = QJMTestUtil.recoverAndReturnLastTxn(qjm);
                    assertTrue("Recovered only up to txnid " + lastRecoveredTxn
                            + " but had gotten an ack for " + lastAckedTxn,
                            lastRecoveredTxn >= lastAckedTxn);

                    writeSegment(cluster, qjm, lastRecoveredTxn + 1, 3, true);
                } catch (Throwable t) {
                    // Test failure! Rethrow with the test setup info so it can be
                    // easily triaged.
                    throw new RuntimeException("Test failed with injection: " + injectionStr, t);
                } finally {
                    // Close the client before stopping the JournalNodes it talks to.
                    // closeStream() is used so a failing close does not skip the shutdown,
                    // whereas the old order leaked the QJM whenever shutdown() threw.
                    IOUtils.closeStream(qjm);
                    cluster.shutdown();
                }
            }
        }
    }

    /**
     * Test case in which three JournalNodes randomly flip flop between
     * up and down states every time they get an RPC.
     * 
     * The writer keeps track of the latest ACKed edit, and on every
     * recovery operation, ensures that it recovers at least to that
     * point or higher. Since at any given point, a majority of JNs
     * may be injecting faults, any writer operation is allowed to fail,
     * so long as the exception message indicates it failed due to injected
     * faults.
     * 
     * Given a random seed, the test should be entirely deterministic.
     */
    @Test
    public void testRandomized() throws Exception {
        long seed = chooseRandomSeed();
        LOG.info("Random seed: " + seed);

        Random r = new Random(seed);

        MiniJournalCluster cluster = new MiniJournalCluster.Builder(conf).build();
        // Everything after the cluster is built runs inside try/finally so the
        // cluster is shut down even if the initial format fails.
        try {
            // Format the cluster using a non-faulty QJM.
            QuorumJournalManager qjmForInitialFormat = createInjectableQJM(cluster);
            try {
                qjmForInitialFormat.format(FAKE_NSINFO);
            } finally {
                qjmForInitialFormat.close();
            }

            long txid = 0;
            long lastAcked = 0;

            for (int i = 0; i < NUM_WRITER_ITERS; i++) {
                LOG.info("Starting writer " + i + "\n-------------------");

                QuorumJournalManager qjm = createRandomFaultyQJM(cluster, r);
                try {
                    long recovered;
                    try {
                        recovered = QJMTestUtil.recoverAndReturnLastTxn(qjm);
                    } catch (Throwable t) {
                        LOG.info("Failed recovery", t);
                        checkException(t);
                        continue;
                    }
                    assertTrue(
                            "Recovered only up to txnid " + recovered + " but had gotten an ack for " + lastAcked,
                            recovered >= lastAcked);

                    txid = recovered + 1;

                    // Periodically purge old data on disk so it's easier to look
                    // at failure cases.
                    if (txid > 100 && i % 10 == 1) {
                        qjm.purgeLogsOlderThan(txid - 100);
                    }

                    Holder<Throwable> thrown = new Holder<Throwable>(null);
                    for (int j = 0; j < SEGMENTS_PER_WRITER; j++) {
                        lastAcked = writeSegmentUntilCrash(qjm, txid, 4, thrown);
                        if (thrown.held != null) {
                            LOG.info("Failed write", thrown.held);
                            checkException(thrown.held);
                            break;
                        }
                        txid += 4;
                    }
                } finally {
                    qjm.close();
                }
            }
        } finally {
            cluster.shutdown();
        }
    }

    /**
     * Picks the seed for {@link #testRandomized()}.
     * <p>
     * If {@link #RAND_SEED_PROPERTY} is set, its value is used and ProtobufRpcEngine
     * logging is raised to {@code ALL} to make the run easier to debug. Otherwise a
     * fresh random seed is returned.
     */
    private static long chooseRandomSeed() {
        Long userSpecifiedSeed = Long.getLong(RAND_SEED_PROPERTY);
        if (userSpecifiedSeed == null) {
            return new Random().nextLong();
        }
        LOG.info("Using seed specified in system property");

        // If the user specifies a seed, then we should gather all the
        // IPC trace information so that debugging is easier. This makes
        // the test run about 25% slower otherwise.
        ((Log4JLogger) ProtobufRpcEngine.LOG).getLogger().setLevel(Level.ALL);
        return userSpecifiedSeed;
    }

    /**
     * Asserts that a failure seen during the randomized test was caused by an
     * injected fault: its message must contain "Injected", and it must not be
     * an AssertionError.
     */
    private static void checkException(Throwable t) {
        GenericTestUtils.assertExceptionContains("Injected", t);
        if (t.toString().contains("AssertionError")) {
            throw new RuntimeException("Should never see AssertionError in fault test!", t);
        }
    }

    /**
     * Attempts to write and finalize one segment of {@code numTxns} transactions
     * starting at {@code txid}, stopping at the first failure.
     *
     * @param qjm the journal manager to write through
     * @param txid the first transaction ID of the segment
     * @param numTxns the number of transactions to write
     * @param thrown receives the first failure; left untouched on success
     * @return the last transaction ID whose write returned without error, or
     *         {@code txid - 1} if none did
     */
    private static long writeSegmentUntilCrash(QuorumJournalManager qjm, long txid, int numTxns,
            Holder<Throwable> thrown) {
        final long firstTxId = txid;
        long lastAcked = firstTxId - 1;
        try {
            EditLogOutputStream stm = qjm.startLogSegment(firstTxId, NameNodeLayoutVersion.CURRENT_LAYOUT_VERSION);

            for (int i = 0; i < numTxns; i++) {
                QJMTestUtil.writeTxns(stm, lastAcked + 1, 1);
                lastAcked++;
            }

            stm.close();
            qjm.finalizeLogSegment(firstTxId, lastAcked);
        } catch (Throwable t) {
            thrown.held = t;
        }
        return lastAcked;
    }

    /**
     * Run a simple workload of becoming the active writer and writing
     * two log segments: 1-3 and 4-6.
     *
     * @return the last transaction ID known to be written (0, 3 or 6); a
     *         QuorumException stops the workload early and is only logged
     */
    private static int doWorkload(MiniJournalCluster cluster, QuorumJournalManager qjm) throws IOException {
        int lastAcked = 0;
        try {
            qjm.recoverUnfinalizedSegments();
            writeSegment(cluster, qjm, 1, 3, true);
            lastAcked = 3;
            writeSegment(cluster, qjm, 4, 3, true);
            lastAcked = 6;
        } catch (QuorumException qe) {
            LOG.info("Failed to write at txid " + lastAcked, qe);
        }
        return lastAcked;
    }

    /**
     * Inject a failure at the given IPC number, such that the JN never
     * receives the RPC. The client side sees an IOException. Future
     * IPCs after this number will be received as usual.
     *
     * @param logger a logger created by {@link #createInjectableQJM}
     * @param idx the 1-based IPC number to fail
     */
    private static void failIpcNumber(AsyncLogger logger, int idx) {
        ((InvocationCountingChannel) logger).failIpcNumber(idx);
    }

    /**
     * Logger channel whose proxy randomly flips between "up" and "down" before
     * each RPC, and which occasionally arms the fault injector to fail around
     * persisting paxos data during acceptRecovery().
     * <p>
     * Calls run on the caller's thread (same-thread executor), presumably so
     * that the sequence of random draws is reproducible for a given seed.
     */
    private static class RandomFaultyChannel extends IPCLoggerChannel {
        /** Probability that any single random draw triggers a fault or state flip. */
        private static final float INJECTION_PROBABILITY = 0.1f;

        private final Random random;
        private boolean isUp = true;

        public RandomFaultyChannel(Configuration conf, NamespaceInfo nsInfo, String journalId,
                InetSocketAddress addr, long seed) {
            super(conf, nsInfo, journalId, addr);
            this.random = new Random(seed);
        }

        @Override
        protected QJournalProtocol createProxy() throws IOException {
            QJournalProtocol realProxy = super.createProxy();
            return mockProxy(new WrapEveryCall<Object>(realProxy) {
                @Override
                void beforeCall(InvocationOnMock invocation) throws Exception {
                    if (random.nextFloat() < INJECTION_PROBABILITY) {
                        isUp = !isUp;
                        LOG.info("transitioned " + addr + " to " + (isUp ? "up" : "down"));
                    }

                    if (!isUp) {
                        throw new IOException("Injected - faking being down");
                    }

                    if (invocation.getMethod().getName().equals("acceptRecovery")) {
                        if (random.nextFloat() < INJECTION_PROBABILITY) {
                            Mockito.doThrow(new IOException("Injected - faking fault before persisting paxos data"))
                                    .when(faultInjector).beforePersistPaxosData();
                        } else if (random.nextFloat() < INJECTION_PROBABILITY) {
                            Mockito.doThrow(new IOException("Injected - faking fault after persisting paxos data"))
                                    .when(faultInjector).afterPersistPaxosData();
                        }
                    }
                }

                @Override
                public void afterCall(InvocationOnMock invocation, boolean succeeded) {
                    // Disarm any fault armed for this call so it cannot leak into the next one.
                    Mockito.reset(faultInjector);
                }
            });
        }

        @Override
        protected ExecutorService createSingleThreadExecutor() {
            return MoreExecutors.sameThreadExecutor();
        }
    }

    /**
     * Logger channel that numbers every outgoing RPC (starting at 1) and can
     * run injected code, typically a thrown exception, just before a chosen RPC.
     */
    private static class InvocationCountingChannel extends IPCLoggerChannel {
        // Not synchronized: presumably only updated on the channel's executor and
        // read after waitForAllPendingCalls() -- TODO confirm.
        private int rpcCount = 0;
        /** Injected code keyed by the 1-based RPC number it should run before. */
        private final Map<Integer, Callable<Void>> injections = Maps.newHashMap();

        public InvocationCountingChannel(Configuration conf, NamespaceInfo nsInfo, String journalId,
                InetSocketAddress addr) {
            super(conf, nsInfo, journalId, addr);
        }

        int getRpcCount() {
            return rpcCount;
        }

        void failIpcNumber(final int idx) {
            Preconditions.checkArgument(idx > 0, "id must be positive");
            inject(idx, new Callable<Void>() {
                @Override
                public Void call() throws Exception {
                    throw new IOException("injected failed IPC at " + idx);
                }
            });
        }

        private void inject(int beforeRpcNumber, Callable<Void> injectedCode) {
            injections.put(beforeRpcNumber, injectedCode);
        }

        @Override
        protected QJournalProtocol createProxy() throws IOException {
            final QJournalProtocol realProxy = super.createProxy();
            QJournalProtocol mock = mockProxy(new WrapEveryCall<Object>(realProxy) {
                @Override
                void beforeCall(InvocationOnMock invocation) throws Exception {
                    rpcCount++;
                    String callStr = "[" + addr + "] " + invocation.getMethod().getName() + "("
                            + Joiner.on(", ").join(invocation.getArguments()) + ")";

                    Callable<Void> inject = injections.get(rpcCount);
                    if (inject != null) {
                        LOG.info("Injecting code before IPC #" + rpcCount + ": " + callStr);
                        inject.call();
                    } else {
                        LOG.info("IPC call #" + rpcCount + ": " + callStr);
                    }
                }
            });
            return mock;
        }
    }

    /**
     * Creates a QJournalProtocol mock whose every call is handled by
     * {@code wrapper}. The mock also implements Closeable.
     */
    private static QJournalProtocol mockProxy(WrapEveryCall<Object> wrapper) throws IOException {
        return Mockito.mock(QJournalProtocol.class,
                Mockito.withSettings().defaultAnswer(wrapper).extraInterfaces(Closeable.class));
    }

    /**
     * Mockito answer that forwards every call to a real object, with hooks
     * before and after the call.
     * <p>
     * {@link #beforeCall} is skipped for methods declared by Closeable. If
     * beforeCall throws, the real call and {@link #afterCall} do not run.
     * Otherwise afterCall runs whether or not the real call succeeded, and an
     * exception thrown by the real method is rethrown unwrapped.
     */
    private static abstract class WrapEveryCall<T> implements Answer<T> {
        private final Object realObj;

        WrapEveryCall(Object realObj) {
            this.realObj = realObj;
        }

        @SuppressWarnings("unchecked")
        @Override
        public T answer(InvocationOnMock invocation) throws Throwable {
            // Don't want to inject an error on close() since that isn't
            // actually an IPC call!
            if (!Closeable.class.equals(invocation.getMethod().getDeclaringClass())) {
                beforeCall(invocation);
            }
            boolean success = false;
            try {
                T ret = (T) invocation.getMethod().invoke(realObj, invocation.getArguments());
                success = true;
                return ret;
            } catch (InvocationTargetException ite) {
                throw ite.getCause();
            } finally {
                afterCall(invocation, success);
            }
        }

        /** Hook run before each non-Closeable call; throwing aborts the call. */
        abstract void beforeCall(InvocationOnMock invocation) throws Exception;

        /** Hook run after each forwarded call; no-op by default. */
        void afterCall(InvocationOnMock invocation, boolean succeeded) {
        }
    }

    /** Creates a QJM whose loggers are {@link InvocationCountingChannel}s. */
    private static QuorumJournalManager createInjectableQJM(MiniJournalCluster cluster)
            throws IOException, URISyntaxException {
        AsyncLogger.Factory spyFactory = new AsyncLogger.Factory() {
            @Override
            public AsyncLogger createLogger(Configuration conf, NamespaceInfo nsInfo, String journalId,
                    InetSocketAddress addr) {
                return new InvocationCountingChannel(conf, nsInfo, journalId, addr);
            }
        };
        return new QuorumJournalManager(conf, cluster.getQuorumJournalURI(JID), FAKE_NSINFO, spyFactory);
    }

    /**
     * Creates a QJM whose loggers are {@link RandomFaultyChannel}s, each seeded
     * from {@code seedGenerator} so the whole run is reproducible from one seed.
     */
    private static QuorumJournalManager createRandomFaultyQJM(MiniJournalCluster cluster,
            final Random seedGenerator) throws IOException, URISyntaxException {

        AsyncLogger.Factory spyFactory = new AsyncLogger.Factory() {
            @Override
            public AsyncLogger createLogger(Configuration conf, NamespaceInfo nsInfo, String journalId,
                    InetSocketAddress addr) {
                return new RandomFaultyChannel(conf, nsInfo, journalId, addr, seedGenerator.nextLong());
            }
        };
        return new QuorumJournalManager(conf, cluster.getQuorumJournalURI(JID), FAKE_NSINFO, spyFactory);
    }

}