Java tutorial
/** * Copyright (C) 2011 Ovea <dev@ovea.com> * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.ovea.mongodb; import com.mongodb.DB; import com.mongodb.Mongo; import com.mycila.xmltool.CallBack; import com.mycila.xmltool.XMLDoc; import com.mycila.xmltool.XMLTag; import com.ovea.system.pipe.Pipes; import com.ovea.system.proc.FutureProcess; import com.ovea.system.util.IoUtils; import com.ovea.system.util.NetUtils; import com.ovea.system.util.ProcUtils; import org.bson.BSONObject; import org.hyperic.sigar.SigarLoader; import java.io.*; import java.net.UnknownHostException; import java.nio.channels.FileChannel; import java.nio.channels.FileLock; import java.util.*; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; /** * @author Mathieu Carbou (mathieu.carbou@gmail.com) */ public abstract class EmbeddedMongoDB { private static final Lock LOCK = new Lock(); private final AtomicBoolean closed = new AtomicBoolean(); private final int port; private final File dbPath; EmbeddedMongoDB(int port, File dbPath) { this.port = port; this.dbPath = dbPath; } public final File dbPath() { return dbPath; } public final int port() { return port; } public final void terminate() { if (closed.compareAndSet(false, true)) { LOCK.lock(); try { remove(this); onTerminate(); } finally { LOCK.unlock(); } } } public final boolean isTerminated() { return NetUtils.isPortAvailable(port()); } @Override public final String toString() { return "MongoDB(port=" + port() + (dbPath == null ? "" : ", db=" + dbPath()) + ")"; } @Override public final boolean equals(Object o) { if (this == o) return true; if (!(o instanceof EmbeddedMongoDB)) return false; EmbeddedMongoDB that = (EmbeddedMongoDB) o; return port() == that.port(); } @Override public final int hashCode() { return port(); } abstract void onTerminate(); public abstract void waitFor() throws InterruptedException; public abstract long pid(); // builders public static EmbeddedMongoDB create(String mongodExecutable) { if (mongodExecutable == null) throw new IllegalArgumentException("Missing MongoDB executable path"); return create(mongodExecutable, "--nohttpinterface", "--nojournal", "--noauth"); } public static EmbeddedMongoDB create(String mongodExecutable, String... args) { if (mongodExecutable == null) throw new IllegalArgumentException("Missing MongoDB executable path"); return add(new ChildProcess(mongodExecutable, args)); } public static EmbeddedMongoDB addRunning(int port) { return addRunning(port, null, null); } public static EmbeddedMongoDB addRunning(int port, String username, String password) { try { Mongo mongo = null; File path = null; try { mongo = new Mongo("localhost", port); DB db = mongo.getDB("admin"); if (username != null && password != null) db.authenticate(username, password.toCharArray()); BSONObject obj = (BSONObject) db.command("getCmdLineOpts").get("parsed"); if (obj != null) { String val = (String) obj.get("dbpath"); if (val != null) path = new File(val); } } finally { if (mongo != null) mongo.close(); } return add(new FileEntry(port, -1, path)); } catch (UnknownHostException e) { throw new RuntimeException(e.getMessage(), e); } } public static List<EmbeddedMongoDB> list() { LOCK.lock(); try { File file = new File(System.getProperty("java.io.tmpdir"), EmbeddedMongoDB.class.getSimpleName() + ".xml"); if (!file.exists()) { return new ArrayList<EmbeddedMongoDB>(0); } XMLTag tag = XMLDoc.from(file, true); final List<EmbeddedMongoDB> list = new ArrayList<EmbeddedMongoDB>(5); tag.forEach("instance", new CallBack() { @Override public void execute(XMLTag doc) { long pid = Long.parseLong(doc.getAttribute("pid")); int port = Integer.parseInt(doc.getAttribute("port")); String db = doc.findAttribute("db"); EmbeddedMongoDB embedded = new FileEntry(port, pid, db == null ? null : new File(db)); if (!embedded.isTerminated()) { list.add(embedded); } } }); return list; } finally { LOCK.unlock(); } } private static void save(List<EmbeddedMongoDB> list) { LOCK.lock(); try { File file = new File(System.getProperty("java.io.tmpdir"), EmbeddedMongoDB.class.getSimpleName() + ".xml"); XMLTag tag = XMLDoc.newDocument(true).addRoot("mongod"); for (EmbeddedMongoDB mongoDB : list) { tag.addTag("instance").addAttribute("pid", "" + mongoDB.pid()).addAttribute("port", "" + mongoDB.port()); if (mongoDB.dbPath != null) tag.addAttribute("db", mongoDB.dbPath.getAbsolutePath()); } Writer w = new FileWriter(file); tag.gotoRoot().toStream(w); w.close(); } catch (IOException e) { throw new RuntimeException(e.getMessage(), e); } finally { LOCK.unlock(); } } public static EmbeddedMongoDB getOrCreate(String mongodExecutable, String... args) { LOCK.lock(); try { List<EmbeddedMongoDB> existing = list(); return existing.isEmpty() ? add(create(mongodExecutable, args)) : existing.get(0); } finally { LOCK.unlock(); } } private static EmbeddedMongoDB add(EmbeddedMongoDB db) { LOCK.lock(); try { List<EmbeddedMongoDB> existing = list(); for (EmbeddedMongoDB exist : existing) { if (exist.equals(db)) { return exist; } } existing.add(db); save(existing); return db; } finally { LOCK.unlock(); } } private static void remove(EmbeddedMongoDB db) { LOCK.lock(); try { List<EmbeddedMongoDB> existing = list(); Iterator<EmbeddedMongoDB> it = existing.iterator(); while (it.hasNext()) { EmbeddedMongoDB cur = it.next(); if (cur.equals(db)) { it.remove(); } } save(existing); } finally { LOCK.unlock(); } } private static final class ChildProcess extends EmbeddedMongoDB { private final FutureProcess process; private final File pidFile; private Long pid; private ChildProcess(String mongodExecutable, String... args) { super(findPort(args), findDbPath(args)); dbPath().mkdirs(); try { List<String> commands = new ArrayList<String>(args.length + 3); if (SigarLoader.IS_WIN32) { commands.add("cmd.exe"); commands.add("/c"); } else { commands.add("sh"); commands.add("-c"); } StringBuilder sb = mongodExecutable.contains(" ") ? new StringBuilder("\"").append(mongodExecutable).append("\"") : new StringBuilder(mongodExecutable); for (String arg : args) { sb.append(" ").append(arg); } List<String> argList = Arrays.asList(args); if (!argList.contains("--dbpath")) { if (dbPath().getAbsolutePath().contains(" ")) { sb.append(" ").append("--dbpath").append("\"").append(dbPath().getAbsolutePath()) .append("\""); } else { sb.append(" ").append("--dbpath ").append(dbPath().getAbsolutePath()); } } if (!argList.contains("--port")) { sb.append(" ").append("--port ").append(port()); } int p = argList.indexOf("--pidfilepath"); if (p == -1) { pidFile = new File(dbPath(), "mongod-" + port() + ".pid"); if (pidFile.getAbsolutePath().contains(" ")) { sb.append(" ").append("--pidfilepath").append(" \"").append(pidFile.getAbsolutePath()) .append("\""); } else { sb.append(" ").append("--pidfilepath ").append(pidFile.getAbsolutePath()); } } else { pidFile = new File(argList.get(p + 1)); } commands.add(sb.toString()); Process mongod = new ProcessBuilder(commands).start(); Pipes.connect("out", mongod.getInputStream(), IoUtils.uncloseable(System.out)); Pipes.connect("err", mongod.getErrorStream(), IoUtils.uncloseable(System.err)); process = new FutureProcess(mongod); while (!Thread.currentThread().isInterrupted() && !process.isDone() && (!pidFile.exists() || !NetUtils.canConnect("localhost", port()))) { try { Thread.sleep(1000); } catch (InterruptedException e) { Thread.currentThread().interrupt(); process.cancel(true); break; } } } catch (IOException e) { throw new RuntimeException(e.getMessage(), e); } } @Override void onTerminate() { long pid = pid(); if (pid != -1) { ProcUtils.terminate(pid); for (int loop = 1; loop <= 20 && !Thread.currentThread().isInterrupted() && !NetUtils.isPortAvailable(port()); loop++) { try { Thread.sleep(1000); } catch (InterruptedException e) { Thread.currentThread().interrupt(); break; } } if (!NetUtils.isPortAvailable(port())) { ProcUtils.kill(pid); } } pidFile.delete(); } @Override public long pid() { if (pid == null) { Scanner scanner = null; try { scanner = new Scanner(pidFile); pid = scanner.nextLong(); } catch (FileNotFoundException e) { return -1; } finally { IoUtils.close(scanner); } } return pid; } @Override public void waitFor() throws InterruptedException { try { process.get(); } catch (ExecutionException e) { throw new RuntimeException(e.getCause().getMessage(), e.getCause()); } } } private static File findDbPath(String... args) { List<String> a = Arrays.asList(args); int pos = a.indexOf("--dbpath"); if (pos == -1) { return new File(System.getProperty("java.io.tmpdir"), "mongodb"); } else { return new File(a.get(pos + 1)); } } private static int findPort(String... args) { List<String> a = Arrays.asList(args); int pos = a.indexOf("--port"); if (pos == -1) { return NetUtils.findAvailablePort(); } else { return Integer.parseInt(a.get(pos + 1)); } } private static final class FileEntry extends EmbeddedMongoDB { private final long pid; private FileEntry(int port, long pid, File dbPath) { super(port, dbPath); this.pid = pid; } @Override void onTerminate() { if (pid != -1) { ProcUtils.terminate(pid); for (int loop = 1; loop <= 20 && !Thread.currentThread().isInterrupted() && !NetUtils.isPortAvailable(port()); loop++) { try { Thread.sleep(1000); } catch (InterruptedException e) { Thread.currentThread().interrupt(); break; } } if (!NetUtils.isPortAvailable(port())) { ProcUtils.kill(pid); } } } @Override public void waitFor() { throw new UnsupportedOperationException(); } @Override public long pid() { return pid; } } private static final class Lock { private final File lockFile = new File(System.getProperty("java.io.tmpdir"), EmbeddedMongoDB.class.getSimpleName() + ".lock"); private final AtomicLong count = new AtomicLong(); private FileChannel channel; private FileLock lock; public void lock() { if (count.getAndIncrement() == 0) { try { channel = new RandomAccessFile(lockFile, "rw").getChannel(); lock = channel.lock(); } catch (IOException e) { throw new RuntimeException(e.getMessage(), e); } } } public void unlock() { if (count.decrementAndGet() == 0) { try { lock.release(); lock = null; } catch (IOException ignored) { } try { channel.close(); channel = null; } catch (IOException ignored) { } lockFile.delete(); } } } }