io.hops.tensorflow.TestCluster.java Source code

Java tutorial

Introduction

Here is the source code for io.hops.tensorflow.TestCluster.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.hops.tensorflow;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileContext;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.JarFinder;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.server.MiniYARNCluster;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.capacity.CapacityScheduler;
import org.junit.After;
import org.junit.Before;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URL;

/**
 * Base class for tests that need a running {@link MiniYARNCluster}.
 *
 * <p>Before each test ({@link #setup()}) a fresh {@link YarnConfiguration} is
 * built and, if no cluster is currently held in {@link #yarnCluster}, a mini
 * cluster is created, started, and its effective configuration is written
 * over the {@code yarn-site.xml} resource found on the classpath. After each
 * test ({@link #tearDown()}) the cluster is stopped and the timeline
 * service's leveldb store directory is removed.
 */
public abstract class TestCluster {

    private static final Log LOG = LogFactory.getLog(TestCluster.class);

    /** Number of NodeManagers started by the default {@link #setup()}. */
    private static final int NUM_NMS = 1;

    /** Configuration key holding the location of the timeline service's leveldb store. */
    private static final String TIMELINE_STORE_PATH_KEY = "yarn.timeline-service.leveldb-timeline-store.path";

    /** Upper bound, in seconds, used by {@link #waitForNMsToRegister()}. */
    private static final int NM_REGISTRATION_TIMEOUT_SEC = 60;

    protected MiniYARNCluster yarnCluster = null;
    protected YarnConfiguration conf = null;

    /**
     * Number of NodeManagers the current cluster was asked to start; updated by
     * {@link #setupInternal(int)} and read by {@link #waitForNMsToRegister()}.
     */
    private int expectedNodeManagers = NUM_NMS;

    protected final static String APPMASTER_JAR = JarFinder.getJar(ApplicationMaster.class);

    /**
     * Starts the mini cluster with the default number of NodeManagers.
     *
     * @throws Exception if the cluster cannot be configured or started
     */
    @Before
    public void setup() throws Exception {
        setupInternal(NUM_NMS);
    }

    /**
     * Builds the test configuration and, when no cluster is running yet, starts
     * a {@link MiniYARNCluster} with the requested number of NodeManagers.
     *
     * @param numNodeManager number of NodeManagers the mini cluster should start
     * @throws Exception if the cluster cannot be started, the
     *                   {@code yarn-site.xml} resource is missing from the
     *                   classpath, or the timeline store cannot be deleted
     */
    protected void setupInternal(int numNodeManager) throws Exception {

        LOG.info("Starting up YARN cluster");

        // Remember the requested size so waitForNMsToRegister() waits for all of them.
        expectedNodeManagers = numNodeManager;

        conf = new YarnConfiguration();
        conf.setInt(YarnConfiguration.RM_SCHEDULER_MINIMUM_ALLOCATION_MB, 128);
        conf.set("yarn.log.dir", "target");
        conf.set("yarn.log-aggregation-enable", "true");
        conf.setBoolean(YarnConfiguration.TIMELINE_SERVICE_ENABLED, true);
        conf.set(YarnConfiguration.RM_SCHEDULER, CapacityScheduler.class.getName());
        conf.setBoolean(YarnConfiguration.NODE_LABELS_ENABLED, true);
        conf.setBoolean(YarnConfiguration.NM_GPU_RESOURCE_ENABLED, false);

        if (yarnCluster == null) {
            yarnCluster = new MiniYARNCluster(TestCluster.class.getSimpleName(), 1, numNodeManager, 1, 1);
            yarnCluster.init(conf);

            yarnCluster.start();

            // The history server's port is only known once the cluster is running.
            conf.set(YarnConfiguration.TIMELINE_SERVICE_WEBAPP_ADDRESS,
                    MiniYARNCluster.getHostname() + ":" + yarnCluster.getApplicationHistoryServer().getPort());

            waitForNMsToRegister();

            writeClusterConfigToClasspath();
        }

        deleteTimelineStore();

        try {
            Thread.sleep(2000);
        } catch (InterruptedException e) {
            LOG.info("setup thread sleep interrupted. message=" + e.getMessage());
            // Restore the flag so the caller can still observe the interruption.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Stops the mini cluster, if one is running, and removes the timeline
     * store directory. The directory is removed even when stopping the cluster
     * fails.
     *
     * @throws IOException if the timeline store cannot be deleted
     */
    @After
    public void tearDown() throws IOException {
        try {
            if (yarnCluster != null) {
                yarnCluster.stop();
            }
        } finally {
            yarnCluster = null;
            deleteTimelineStore();
        }
    }

    /**
     * Polls the ResourceManager once per second, for roughly
     * {@value #NM_REGISTRATION_TIMEOUT_SEC} seconds, until the number of
     * registered nodes reaches the number of NodeManagers requested in
     * {@link #setupInternal(int)}. If the deadline passes first, a warning is
     * logged and the method returns normally.
     *
     * @throws Exception if the waiting thread is interrupted
     */
    protected void waitForNMsToRegister() throws Exception {
        int sec = NM_REGISTRATION_TIMEOUT_SEC;
        while (sec >= 0) {
            if (registeredNodeManagers() >= expectedNodeManagers) {
                return;
            }
            Thread.sleep(1000);
            sec--;
        }
        LOG.warn("Timed out waiting for NodeManagers to register: expected=" + expectedNodeManagers
                + ", registered=" + registeredNodeManagers());
    }

    /** Returns the number of nodes currently known to the ResourceManager. */
    private int registeredNodeManagers() {
        return yarnCluster.getResourceManager().getRMContext().getRMNodes().size();
    }

    /**
     * Overwrites the {@code yarn-site.xml} found on the classpath with the
     * running cluster's configuration, after pointing
     * {@code yarn.application.classpath} at the directory containing that file.
     */
    private void writeClusterConfigToClasspath() throws IOException {
        URL url = Thread.currentThread().getContextClassLoader().getResource("yarn-site.xml");
        if (url == null) {
            throw new RuntimeException("Could not find 'yarn-site.xml' dummy file in classpath");
        }
        File yarnSiteFile = new File(url.getPath());

        Configuration yarnClusterConfig = yarnCluster.getConfig();
        yarnClusterConfig.set("yarn.application.classpath", yarnSiteFile.getParent());

        // Serialize to memory first rather than straight to the file: writing
        // directly can cause the half-written file to be read back, which fails.
        ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
        yarnClusterConfig.writeXml(bytesOut);

        // Now write the bytes to the file in the classpath; the stream is
        // closed even if the write fails.
        try (OutputStream os = new FileOutputStream(yarnSiteFile)) {
            os.write(bytesOut.toByteArray());
        }
    }

    /**
     * Recursively deletes the timeline service's leveldb store directory on the
     * local file system. Does nothing when no configuration has been created
     * yet or the path is not configured.
     */
    private void deleteTimelineStore() throws IOException {
        if (conf == null) {
            return;
        }
        String storePath = conf.get(TIMELINE_STORE_PATH_KEY);
        if (storePath == null) {
            return;
        }
        FileContext fsContext = FileContext.getLocalFSFileContext();
        fsContext.delete(new Path(storePath), true);
    }
}