com.teradata.presto.yarn.test.utils.NodeSshUtils.java Source code

Java tutorial

Introduction

Here is the source code for com.teradata.presto.yarn.test.utils.NodeSshUtils.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.teradata.presto.yarn.test.utils;

import com.google.common.base.Joiner;
import com.google.common.base.Splitter;
import com.teradata.tempto.context.State;
import com.teradata.tempto.ssh.SshClient;
import com.teradata.tempto.ssh.SshClientFactory;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Stream;

import static com.google.common.base.CharMatcher.anyOf;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Sets.newHashSet;
import static com.teradata.presto.yarn.test.utils.TimeUtils.retryUntil;
import static java.lang.Long.parseLong;
import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.MINUTES;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;

/**
 * SSH-based helpers for inspecting Presto server processes on cluster nodes and for managing
 * YARN node labels.
 *
 * <p>Commands aimed at a specific node are executed through a short-lived {@link SshClient}
 * obtained from the {@link SshClientFactory}; YARN CLI commands are executed through the
 * dedicated {@code yarnSshClient}.
 */
public class NodeSshUtils implements State {
    private static final Logger log = LoggerFactory.getLogger(NodeSshUtils.class);

    private final SshClientFactory sshClientFactory;
    private final SshClient yarnSshClient;

    /**
     * @param sshClientFactory factory used to open a client for an arbitrary host
     * @param yarnSshClient client used for all {@code yarn ...} commands
     */
    public NodeSshUtils(SshClientFactory sshClientFactory, SshClient yarnSshClient) {
        this.yarnSshClient = yarnSshClient;
        this.sshClientFactory = sshClientFactory;
    }

    /**
     * Checks whether a Presto server process is listed by {@code ps} on the given host.
     *
     * @param host host to inspect over SSH
     * @return {@code true} if exactly one matching process is listed, {@code false} if none is
     * @throws IllegalStateException if more than one matching process is listed
     */
    public boolean isPrestoProcessRunning(String host) {
        return withSshClient(host, sshClient -> {
            // "|| true" keeps the pipeline's exit status zero when grep matches nothing.
            String prestoProcesses = sshClient.command("ps aux | grep PrestoServer | grep -v grep || true")
                    .trim();
            // Splitting an empty string still yields one element, so handle "no output" explicitly.
            int processesCount = StringUtils.isEmpty(prestoProcesses) ? 0 : prestoProcesses.split("\n").length;

            log.info("Presto processes count on {}: {}", host, processesCount);
            checkState(processesCount == 0 || processesCount == 1,
                    "Unexpected number of presto proceses: %s on: %s", processesCount, host);
            return processesCount == 1;
        });
    }

    /**
     * Sends SIGKILL to processes matching {@code java.*PrestoServer.*} on the given host and then
     * polls (via {@code retryUntil}, with a 10 second limit) until no Presto process is listed.
     *
     * @param host host on which the processes are killed
     */
    public void killPrestoProcesses(String host) {
        runOnNode(host, singletonList("pkill -9 -f 'java.*PrestoServer.*'"));
        retryUntil(() -> !isPrestoProcessRunning(host), TimeUnit.SECONDS.toMillis(10));
    }

    /**
     * Reads a heap capacity figure of the Presto server JVM on the given host.
     *
     * <p>The value is the third whitespace-separated field of the first line containing
     * {@code capacity} in the output of {@code jmap -heap <pid>}.
     *
     * @param host host running the Presto server
     * @return the parsed capacity value as reported by {@code jmap}
     * @throws NumberFormatException if the command output is not a single integer
     */
    public long getPrestoJvmMemory(String host) {
        return withSshClient(host, sshClient -> {
            String prestoServerPid = sshClient.command("pgrep -f PrestoServer").trim();
            // Trim before parsing: surrounding whitespace (e.g. a trailing newline in the command
            // output) would otherwise make parseLong fail.
            String heapCapacity = sshClient.command(
                    "jmap -heap " + prestoServerPid + " | grep capacity | awk 'NR == 1' | awk '{print $3}'")
                    .trim();
            long prestoProcessJvm = parseLong(heapCapacity);
            log.info("Presto jvm memory {}: {}", host, prestoProcessJvm);
            return prestoProcessJvm;
        });
    }

    /**
     * Returns the trimmed {@code ps aux} line(s) that mention {@code PrestoServer} on the given host.
     *
     * @param host host to inspect over SSH
     * @return trimmed command output
     */
    public String getPrestoJvmProcess(String host) {
        return withSshClient(host,
                sshClient -> sshClient.command("ps aux | grep PrestoServer | grep -v grep").trim());
    }

    /**
     * Registers the distinct label values of the given map as cluster node labels.
     *
     * @param labels map whose values are the labels to register (keys are not used here)
     * @return trimmed output of the {@code yarn rmadmin -addToClusterNodeLabels} command
     */
    public String createLabels(Map<String, String> labels) {
        // newHashSet de-duplicates labels shared by several nodes.
        return commandOnYarn(
                "yarn rmadmin -addToClusterNodeLabels " + Joiner.on(",").join(newHashSet(labels.values())));
    }

    /**
     * Assigns a label to each node, refreshes the queues and verifies the result.
     *
     * <p>Does nothing when {@code labels} is empty.
     *
     * @param labels map from node name to the label that node should carry
     * @throws IllegalStateException if a node has no matching running YARN node id, or if the
     *         labels reported by YARN afterwards do not contain the expected ones
     */
    public void labelNodes(Map<String, String> labels) {
        if (labels.isEmpty()) {
            // Nothing to label; an empty argument would also produce a meaningless yarn command.
            return;
        }

        waitForNodeManagers(labels.size());
        List<String> nodeIds = getNodeIds();

        Map<String, String> nodeToNodeIds = labels.keySet().stream()
                .collect(toMap(node -> node, node -> findNodeId(nodeIds, node)));

        // Builds "node1,label1 node2,label2 ..."
        String replaceLabelsArgument = labels.entrySet().stream()
                .map(entry -> entry.getKey() + "," + entry.getValue())
                .reduce(joinOn(" "))
                .orElse("");
        commandOnYarn("yarn rmadmin -replaceLabelsOnNode \'" + replaceLabelsArgument + "\'");
        commandOnYarn("yarn rmadmin -refreshQueues");

        checkThatLabelsAreSetCorrectly(labels, nodeToNodeIds);
    }

    /** Returns the first node id containing {@code node}, failing with a descriptive message if none does. */
    private static String findNodeId(List<String> nodeIds, String node) {
        return nodeIds.stream()
                .filter(nodeId -> nodeId.contains(node))
                .findFirst()
                .orElseThrow(() -> new IllegalStateException(
                        "No running YARN node id contains '" + node + "'; known node ids: " + nodeIds));
    }

    /** Returns an operator that concatenates two strings with {@code separator} between them. */
    private static BinaryOperator<String> joinOn(String separator) {
        return (first, second) -> first + separator + second;
    }

    /**
     * Verifies that every expected label shows up both in the default queue's accessible node
     * labels and in the status output of the node it was assigned to.
     */
    private void checkThatLabelsAreSetCorrectly(Map<String, String> labels, Map<String, String> nodeToNodeIds) {
        String clusterNodeLabels = commandOnYarn("yarn queue -status default | grep 'Accessible Node Labels'");
        // Guava's checkState only understands %s placeholders (not SLF4J-style {}).
        labels.values().forEach(label -> checkState(clusterNodeLabels.contains(label),
                "Cluster node labels '%s', does not contain label '%s'", clusterNodeLabels, label));
        labels.forEach((node, label) -> {
            String nodeLabels = commandOnYarn(
                    "yarn node -status " + nodeToNodeIds.get(node) + " | grep \'Node-Labels\'");
            checkState(nodeLabels.contains(label), "Node labels '%s' on node '%s' does not contain label '%s'",
                    nodeLabels, node, label);
        });
    }

    /** Polls (via {@code retryUntil}, with a 2 minute limit) until at least {@code numberOfNodes} node ids are listed. */
    private void waitForNodeManagers(int numberOfNodes) {
        log.info("Waiting for NodeManagers...");
        retryUntil(() -> getNodeIds().size() >= numberOfNodes, MINUTES.toMillis(2));
    }

    /**
     * Lists the ids of YARN nodes reported as {@code RUNNING}.
     *
     * @return the first whitespace-separated token of every {@code yarn node -list} output line
     *         that contains {@code RUNNING}
     */
    public List<String> getNodeIds() {
        Splitter columnSplitter = Splitter.on(anyOf(" \t")).omitEmptyStrings().trimResults();
        return Stream.of(commandOnYarn("yarn node -list").split("\n"))
                .filter(line -> line.contains("RUNNING"))
                .map(line -> columnSplitter.split(line).iterator().next())
                .collect(toList());
    }

    /**
     * Runs a command through the YARN SSH client after sourcing {@code /etc/profile}.
     *
     * @param command shell command to run
     * @return trimmed command output
     */
    public String commandOnYarn(String command) {
        return yarnSshClient.command("source /etc/profile && " + command).trim();
    }

    /**
     * Runs a single command on the given node; the command output is discarded.
     *
     * @param node host to run the command on
     * @param command shell command to run
     */
    public void runOnNode(String node, String command) {
        runOnNode(node, singletonList(command));
    }

    /**
     * Runs the given commands, in order, over one SSH client opened for the node.
     *
     * @param node host to run the commands on
     * @param commands shell commands to run
     * @return outputs of the commands, in the same order
     */
    public List<String> runOnNode(String node, List<String> commands) {
        return withSshClient(node, sshClient -> commands.stream().map(sshClient::command).collect(toList()));
    }

    /**
     * Applies {@code closure} to a fresh SSH client for each host, one host after another.
     *
     * @param hosts hosts to connect to
     * @param closure action to run with each client
     * @param <T> result type of the action
     * @return results in the iteration order of {@code hosts}
     */
    public <T> List<T> withSshClient(Collection<String> hosts, Function<SshClient, T> closure) {
        return hosts.stream().map(host -> withSshClient(host, closure)).collect(toList());
    }

    /**
     * Opens an SSH client for {@code host}, applies {@code function} to it and closes the client.
     *
     * @param host host to connect to
     * @param function action to run with the client
     * @param <T> result type of the action
     * @return the value returned by {@code function}
     * @throws RuntimeException wrapping any {@link IOException} raised inside the try-with-resources block
     */
    public <T> T withSshClient(String host, Function<SshClient, T> function) {
        try (SshClient sshClient = sshClientFactory.create(host)) {
            return function.apply(sshClient);
        } catch (IOException e) {
            throw new RuntimeException("I/O failure while using SSH client for host: " + host, e);
        }
    }

    /** This state has no name. */
    @Override
    public Optional<String> getName() {
        return Optional.empty();
    }
}