// Java tutorial
/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.net.unix; import java.io.File; import java.io.FileDescriptor; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.SocketTimeoutException; import java.nio.ByteBuffer; import java.nio.channels.AsynchronousCloseException; import java.nio.channels.ClosedChannelException; import java.util.Arrays; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import org.junit.AfterClass; import org.junit.Assert; import org.junit.Assume; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import org.apache.commons.lang.exception.ExceptionUtils; import org.apache.hadoop.io.IOUtils; import org.apache.hadoop.net.unix.DomainSocket.DomainChannel; import org.apache.hadoop.test.GenericTestUtils; import org.apache.hadoop.util.Shell; import com.google.common.io.Files; public class TestDomainSocket { 
private static TemporarySocketDirectory sockDir; @BeforeClass public static void init() { sockDir = new TemporarySocketDirectory(); DomainSocket.disableBindPathValidation(); } @AfterClass public static void shutdown() throws IOException { sockDir.close(); } @Before public void before() { Assume.assumeTrue(DomainSocket.getLoadingFailureReason() == null); } /** * Test that we can create a socket and close it, even if it hasn't been * opened. * * @throws IOException */ @Test(timeout = 180000) public void testSocketCreateAndClose() throws IOException { DomainSocket serv = DomainSocket .bindAndListen(new File(sockDir.getDir(), "test_sock_create_and_close").getAbsolutePath()); serv.close(); } /** * Test DomainSocket path setting and getting. * * @throws IOException */ @Test(timeout = 180000) public void testSocketPathSetGet() throws IOException { Assert.assertEquals("/var/run/hdfs/sock.100", DomainSocket.getEffectivePath("/var/run/hdfs/sock._PORT", 100)); } /** * Test that we get a read result of -1 on EOF. 
* * @throws IOException */ @Test(timeout = 180000) public void testSocketReadEof() throws Exception { final String TEST_PATH = new File(sockDir.getDir(), "testSocketReadEof").getAbsolutePath(); final DomainSocket serv = DomainSocket.bindAndListen(TEST_PATH); ExecutorService exeServ = Executors.newSingleThreadExecutor(); Callable<Void> callable = new Callable<Void>() { public Void call() { DomainSocket conn; try { conn = serv.accept(); } catch (IOException e) { throw new RuntimeException("unexpected IOException", e); } byte buf[] = new byte[100]; for (int i = 0; i < buf.length; i++) { buf[i] = 0; } try { Assert.assertEquals(-1, conn.getInputStream().read()); } catch (IOException e) { throw new RuntimeException("unexpected IOException", e); } return null; } }; Future<Void> future = exeServ.submit(callable); DomainSocket conn = DomainSocket.connect(serv.getPath()); Thread.sleep(50); conn.close(); serv.close(); future.get(2, TimeUnit.MINUTES); } /** * Test that if one thread is blocking in a read or write operation, another * thread can close the socket and stop the accept. * * @throws IOException */ @Test(timeout = 180000) public void testSocketAcceptAndClose() throws Exception { final String TEST_PATH = new File(sockDir.getDir(), "test_sock_accept_and_close").getAbsolutePath(); final DomainSocket serv = DomainSocket.bindAndListen(TEST_PATH); ExecutorService exeServ = Executors.newSingleThreadExecutor(); Callable<Void> callable = new Callable<Void>() { public Void call() { try { serv.accept(); throw new RuntimeException("expected the accept() to be " + "interrupted and fail"); } catch (AsynchronousCloseException e) { return null; } catch (IOException e) { throw new RuntimeException("unexpected IOException", e); } } }; Future<Void> future = exeServ.submit(callable); Thread.sleep(500); serv.close(); future.get(2, TimeUnit.MINUTES); } /** * Test that we get an AsynchronousCloseException when the DomainSocket * we're using is closed during a read or write operation. 
* * @throws IOException */ private void testAsyncCloseDuringIO(final boolean closeDuringWrite) throws Exception { final String TEST_PATH = new File(sockDir.getDir(), "testAsyncCloseDuringIO(" + closeDuringWrite + ")") .getAbsolutePath(); final DomainSocket serv = DomainSocket.bindAndListen(TEST_PATH); ExecutorService exeServ = Executors.newFixedThreadPool(2); Callable<Void> serverCallable = new Callable<Void>() { public Void call() { DomainSocket serverConn = null; try { serverConn = serv.accept(); byte buf[] = new byte[100]; for (int i = 0; i < buf.length; i++) { buf[i] = 0; } // The server just continues either writing or reading until someone // asynchronously closes the client's socket. At that point, all our // reads return EOF, and writes get a socket error. if (closeDuringWrite) { try { while (true) { serverConn.getOutputStream().write(buf); } } catch (IOException e) { } } else { do { ; } while (serverConn.getInputStream().read(buf, 0, buf.length) != -1); } } catch (IOException e) { throw new RuntimeException("unexpected IOException", e); } finally { IOUtils.cleanup(DomainSocket.LOG, serverConn); } return null; } }; Future<Void> serverFuture = exeServ.submit(serverCallable); final DomainSocket clientConn = DomainSocket.connect(serv.getPath()); Callable<Void> clientCallable = new Callable<Void>() { public Void call() { // The client writes or reads until another thread // asynchronously closes the socket. At that point, we should // get ClosedChannelException, or possibly its subclass // AsynchronousCloseException. 
byte buf[] = new byte[100]; for (int i = 0; i < buf.length; i++) { buf[i] = 0; } try { if (closeDuringWrite) { while (true) { clientConn.getOutputStream().write(buf); } } else { while (true) { clientConn.getInputStream().read(buf, 0, buf.length); } } } catch (ClosedChannelException e) { return null; } catch (IOException e) { throw new RuntimeException("unexpected IOException", e); } } }; Future<Void> clientFuture = exeServ.submit(clientCallable); Thread.sleep(500); clientConn.close(); serv.close(); clientFuture.get(2, TimeUnit.MINUTES); serverFuture.get(2, TimeUnit.MINUTES); } @Test(timeout = 180000) public void testAsyncCloseDuringWrite() throws Exception { testAsyncCloseDuringIO(true); } @Test(timeout = 180000) public void testAsyncCloseDuringRead() throws Exception { testAsyncCloseDuringIO(false); } /** * Test that attempting to connect to an invalid path doesn't work. * * @throws IOException */ @Test(timeout = 180000) public void testInvalidOperations() throws IOException { try { DomainSocket.connect(new File(sockDir.getDir(), "test_sock_invalid_operation").getAbsolutePath()); } catch (IOException e) { GenericTestUtils.assertExceptionContains("connect(2) error: ", e); } } /** * Test setting some server options. 
* * @throws IOException */ @Test(timeout = 180000) public void testServerOptions() throws Exception { final String TEST_PATH = new File(sockDir.getDir(), "test_sock_server_options").getAbsolutePath(); DomainSocket serv = DomainSocket.bindAndListen(TEST_PATH); try { // Let's set a new receive buffer size int bufSize = serv.getAttribute(DomainSocket.RECEIVE_BUFFER_SIZE); int newBufSize = bufSize / 2; serv.setAttribute(DomainSocket.RECEIVE_BUFFER_SIZE, newBufSize); int nextBufSize = serv.getAttribute(DomainSocket.RECEIVE_BUFFER_SIZE); Assert.assertEquals(newBufSize, nextBufSize); // Let's set a server timeout int newTimeout = 1000; serv.setAttribute(DomainSocket.RECEIVE_TIMEOUT, newTimeout); int nextTimeout = serv.getAttribute(DomainSocket.RECEIVE_TIMEOUT); Assert.assertEquals(newTimeout, nextTimeout); try { serv.accept(); Assert.fail("expected the accept() to time out and fail"); } catch (SocketTimeoutException e) { GenericTestUtils.assertExceptionContains("accept(2) error: ", e); } } finally { serv.close(); Assert.assertFalse(serv.isOpen()); } } /** * A Throwable representing success. * * We can't use null to represent this, because you cannot insert null into * ArrayBlockingQueue. */ static class Success extends Throwable { private static final long serialVersionUID = 1L; } static interface WriteStrategy { /** * Initialize a WriteStrategy object from a Socket. */ public void init(DomainSocket s) throws IOException; /** * Write some bytes. */ public void write(byte b[]) throws IOException; } static class OutputStreamWriteStrategy implements WriteStrategy { private OutputStream outs = null; public void init(DomainSocket s) throws IOException { outs = s.getOutputStream(); } public void write(byte b[]) throws IOException { outs.write(b); } } abstract static class ReadStrategy { /** * Initialize a ReadStrategy object from a DomainSocket. */ public abstract void init(DomainSocket s) throws IOException; /** * Read some bytes. 
*/ public abstract int read(byte b[], int off, int length) throws IOException; public void readFully(byte buf[], int off, int len) throws IOException { int toRead = len; while (toRead > 0) { int ret = read(buf, off, toRead); if (ret < 0) { throw new IOException("Premature EOF from inputStream"); } toRead -= ret; off += ret; } } } static class InputStreamReadStrategy extends ReadStrategy { private InputStream ins = null; @Override public void init(DomainSocket s) throws IOException { ins = s.getInputStream(); } @Override public int read(byte b[], int off, int length) throws IOException { return ins.read(b, off, length); } } static class DirectByteBufferReadStrategy extends ReadStrategy { private DomainChannel ch = null; @Override public void init(DomainSocket s) throws IOException { ch = s.getChannel(); } @Override public int read(byte b[], int off, int length) throws IOException { ByteBuffer buf = ByteBuffer.allocateDirect(b.length); int nread = ch.read(buf); if (nread < 0) return nread; buf.flip(); buf.get(b, off, nread); return nread; } } static class ArrayBackedByteBufferReadStrategy extends ReadStrategy { private DomainChannel ch = null; @Override public void init(DomainSocket s) throws IOException { ch = s.getChannel(); } @Override public int read(byte b[], int off, int length) throws IOException { ByteBuffer buf = ByteBuffer.wrap(b); int nread = ch.read(buf); if (nread < 0) return nread; buf.flip(); buf.get(b, off, nread); return nread; } } /** * Test a simple client/server interaction. * * @throws IOException */ void testClientServer1(final Class<? extends WriteStrategy> writeStrategyClass, final Class<? 
extends ReadStrategy> readStrategyClass, final DomainSocket preConnectedSockets[]) throws Exception { final String TEST_PATH = new File(sockDir.getDir(), "test_sock_client_server1").getAbsolutePath(); final byte clientMsg1[] = new byte[] { 0x1, 0x2, 0x3, 0x4, 0x5, 0x6 }; final byte serverMsg1[] = new byte[] { 0x9, 0x8, 0x7, 0x6, 0x5 }; final byte clientMsg2 = 0x45; final ArrayBlockingQueue<Throwable> threadResults = new ArrayBlockingQueue<Throwable>(2); final DomainSocket serv = (preConnectedSockets != null) ? null : DomainSocket.bindAndListen(TEST_PATH); Thread serverThread = new Thread() { public void run() { // Run server DomainSocket conn = null; try { conn = preConnectedSockets != null ? preConnectedSockets[0] : serv.accept(); byte in1[] = new byte[clientMsg1.length]; ReadStrategy reader = readStrategyClass.newInstance(); reader.init(conn); reader.readFully(in1, 0, in1.length); Assert.assertTrue(Arrays.equals(clientMsg1, in1)); WriteStrategy writer = writeStrategyClass.newInstance(); writer.init(conn); writer.write(serverMsg1); InputStream connInputStream = conn.getInputStream(); int in2 = connInputStream.read(); Assert.assertEquals((int) clientMsg2, in2); conn.close(); } catch (Throwable e) { threadResults.add(e); Assert.fail(e.getMessage()); } threadResults.add(new Success()); } }; serverThread.start(); Thread clientThread = new Thread() { public void run() { try { DomainSocket client = preConnectedSockets != null ? 
preConnectedSockets[1] : DomainSocket.connect(TEST_PATH); WriteStrategy writer = writeStrategyClass.newInstance(); writer.init(client); writer.write(clientMsg1); ReadStrategy reader = readStrategyClass.newInstance(); reader.init(client); byte in1[] = new byte[serverMsg1.length]; reader.readFully(in1, 0, in1.length); Assert.assertTrue(Arrays.equals(serverMsg1, in1)); OutputStream clientOutputStream = client.getOutputStream(); clientOutputStream.write(clientMsg2); client.close(); } catch (Throwable e) { threadResults.add(e); } threadResults.add(new Success()); } }; clientThread.start(); for (int i = 0; i < 2; i++) { Throwable t = threadResults.take(); if (!(t instanceof Success)) { Assert.fail(t.getMessage() + ExceptionUtils.getStackTrace(t)); } } serverThread.join(120000); clientThread.join(120000); if (serv != null) { serv.close(); } } @Test(timeout = 180000) public void testClientServerOutStreamInStream() throws Exception { testClientServer1(OutputStreamWriteStrategy.class, InputStreamReadStrategy.class, null); } @Test(timeout = 180000) public void testClientServerOutStreamInStreamWithSocketpair() throws Exception { testClientServer1(OutputStreamWriteStrategy.class, InputStreamReadStrategy.class, DomainSocket.socketpair()); } @Test(timeout = 180000) public void testClientServerOutStreamInDbb() throws Exception { testClientServer1(OutputStreamWriteStrategy.class, DirectByteBufferReadStrategy.class, null); } @Test(timeout = 180000) public void testClientServerOutStreamInDbbWithSocketpair() throws Exception { testClientServer1(OutputStreamWriteStrategy.class, DirectByteBufferReadStrategy.class, DomainSocket.socketpair()); } @Test(timeout = 180000) public void testClientServerOutStreamInAbb() throws Exception { testClientServer1(OutputStreamWriteStrategy.class, ArrayBackedByteBufferReadStrategy.class, null); } @Test(timeout = 180000) public void testClientServerOutStreamInAbbWithSocketpair() throws Exception { testClientServer1(OutputStreamWriteStrategy.class, 
ArrayBackedByteBufferReadStrategy.class, DomainSocket.socketpair()); } static private class PassedFile { private final int idx; private final byte[] contents; private FileInputStream fis; public PassedFile(int idx) throws IOException { this.idx = idx; this.contents = new byte[] { (byte) (idx % 127) }; Files.write(contents, new File(getPath())); this.fis = new FileInputStream(getPath()); } public String getPath() { return new File(sockDir.getDir(), "passed_file" + idx).getAbsolutePath(); } public FileInputStream getInputStream() throws IOException { return fis; } public void cleanup() throws IOException { new File(getPath()).delete(); fis.close(); } public void checkInputStream(FileInputStream fis) throws IOException { byte buf[] = new byte[contents.length]; IOUtils.readFully(fis, buf, 0, buf.length); Arrays.equals(contents, buf); } protected void finalize() { try { cleanup(); } catch (Throwable t) { // ignore } } } /** * Test file descriptor passing. * * @throws IOException */ @Test(timeout = 180000) public void testFdPassing() throws Exception { final String TEST_PATH = new File(sockDir.getDir(), "test_sock").getAbsolutePath(); final byte clientMsg1[] = new byte[] { 0x11, 0x22, 0x33, 0x44, 0x55, 0x66 }; final byte serverMsg1[] = new byte[] { 0x31, 0x30, 0x32, 0x34, 0x31, 0x33, 0x44, 0x1, 0x1, 0x1, 0x1, 0x1 }; final ArrayBlockingQueue<Throwable> threadResults = new ArrayBlockingQueue<Throwable>(2); final DomainSocket serv = DomainSocket.bindAndListen(TEST_PATH); final PassedFile passedFiles[] = new PassedFile[] { new PassedFile(1), new PassedFile(2) }; final FileDescriptor passedFds[] = new FileDescriptor[passedFiles.length]; for (int i = 0; i < passedFiles.length; i++) { passedFds[i] = passedFiles[i].getInputStream().getFD(); } Thread serverThread = new Thread() { public void run() { // Run server DomainSocket conn = null; try { conn = serv.accept(); byte in1[] = new byte[clientMsg1.length]; InputStream connInputStream = conn.getInputStream(); 
IOUtils.readFully(connInputStream, in1, 0, in1.length); Assert.assertTrue(Arrays.equals(clientMsg1, in1)); DomainSocket domainConn = (DomainSocket) conn; domainConn.sendFileDescriptors(passedFds, serverMsg1, 0, serverMsg1.length); conn.close(); } catch (Throwable e) { threadResults.add(e); Assert.fail(e.getMessage()); } threadResults.add(new Success()); } }; serverThread.start(); Thread clientThread = new Thread() { public void run() { try { DomainSocket client = DomainSocket.connect(TEST_PATH); OutputStream clientOutputStream = client.getOutputStream(); InputStream clientInputStream = client.getInputStream(); clientOutputStream.write(clientMsg1); DomainSocket domainConn = (DomainSocket) client; byte in1[] = new byte[serverMsg1.length]; FileInputStream recvFis[] = new FileInputStream[passedFds.length]; int r = domainConn.recvFileInputStreams(recvFis, in1, 0, in1.length - 1); Assert.assertTrue(r > 0); IOUtils.readFully(clientInputStream, in1, r, in1.length - r); Assert.assertTrue(Arrays.equals(serverMsg1, in1)); for (int i = 0; i < passedFds.length; i++) { Assert.assertNotNull(recvFis[i]); passedFiles[i].checkInputStream(recvFis[i]); } for (FileInputStream fis : recvFis) { fis.close(); } client.close(); } catch (Throwable e) { threadResults.add(e); } threadResults.add(new Success()); } }; clientThread.start(); for (int i = 0; i < 2; i++) { Throwable t = threadResults.take(); if (!(t instanceof Success)) { Assert.fail(t.getMessage() + ExceptionUtils.getStackTrace(t)); } } serverThread.join(120000); clientThread.join(120000); serv.close(); for (PassedFile pf : passedFiles) { pf.cleanup(); } } /** * Run validateSocketPathSecurity * * @param str The path to validate * @param prefix A prefix to skip validation for * @throws IOException */ private static void testValidateSocketPath(String str, String prefix) throws IOException { int skipComponents = 1; File prefixFile = new File(prefix); while (true) { prefixFile = prefixFile.getParentFile(); if (prefixFile == null) { 
break; } skipComponents++; } DomainSocket.validateSocketPathSecurity0(str, skipComponents); } /** * Test file descriptor path security. * * @throws IOException */ @Test(timeout = 180000) public void testFdPassingPathSecurity() throws Exception { TemporarySocketDirectory tmp = new TemporarySocketDirectory(); try { String prefix = tmp.getDir().getAbsolutePath(); Shell.execCommand(new String[] { "mkdir", "-p", prefix + "/foo/bar/baz" }); Shell.execCommand(new String[] { "chmod", "0700", prefix + "/foo/bar/baz" }); Shell.execCommand(new String[] { "chmod", "0700", prefix + "/foo/bar" }); Shell.execCommand(new String[] { "chmod", "0707", prefix + "/foo" }); Shell.execCommand(new String[] { "mkdir", "-p", prefix + "/q1/q2" }); Shell.execCommand(new String[] { "chmod", "0700", prefix + "/q1" }); Shell.execCommand(new String[] { "chmod", "0700", prefix + "/q1/q2" }); testValidateSocketPath(prefix + "/q1/q2", prefix); try { testValidateSocketPath(prefix + "/foo/bar/baz", prefix); } catch (IOException e) { GenericTestUtils.assertExceptionContains("world-writable", e); GenericTestUtils.assertExceptionContains("/foo'", e); } try { testValidateSocketPath(prefix + "/nope", prefix); } catch (IOException e) { GenericTestUtils.assertExceptionContains("failed to stat a path " + "component: ", e); } // Root should be secure DomainSocket.validateSocketPathSecurity0("/foo", 0); } finally { tmp.close(); } } @Test(timeout = 180000) public void testShutdown() throws Exception { final AtomicInteger bytesRead = new AtomicInteger(0); final AtomicBoolean failed = new AtomicBoolean(false); final DomainSocket[] socks = DomainSocket.socketpair(); Runnable reader = new Runnable() { @Override public void run() { while (true) { try { int ret = socks[1].getInputStream().read(); if (ret == -1) return; bytesRead.addAndGet(1); } catch (IOException e) { DomainSocket.LOG.error("reader error", e); failed.set(true); return; } } } }; Thread readerThread = new Thread(reader); readerThread.start(); 
socks[0].getOutputStream().write(1); socks[0].getOutputStream().write(2); socks[0].getOutputStream().write(3); Assert.assertTrue(readerThread.isAlive()); socks[0].shutdown(); readerThread.join(); Assert.assertFalse(failed.get()); Assert.assertEquals(3, bytesRead.get()); IOUtils.cleanup(null, socks); } }