com.emc.vipr.ribbon.ViPRDataServicesServerList.java Source code

Java tutorial

Introduction

Here is the source code for com.emc.vipr.ribbon.ViPRDataServicesServerList.java

Source

/*
 * Copyright 2014 EMC Corporation. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 * http://www.apache.org/licenses/LICENSE-2.0.txt
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */
package com.emc.vipr.ribbon;

import com.emc.vipr.ribbon.bean.ListDataNode;
import com.netflix.client.config.IClientConfig;
import com.netflix.loadbalancer.AbstractServerList;
import com.netflix.loadbalancer.Server;
import org.apache.commons.codec.binary.Base64;
import org.apache.http.HttpResponse;
import org.apache.http.client.HttpClient;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.DefaultHttpClient;
import org.apache.http.impl.conn.PoolingClientConnectionManager;
import org.apache.http.params.HttpConnectionParams;
import org.apache.http.util.EntityUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Unmarshaller;
import java.io.IOException;
import java.io.InputStream;
import java.text.SimpleDateFormat;
import java.util.*;

/**
 * Ribbon server list that discovers the data nodes of a ViPR data services deployment.
 * <p>
 * The initial list comes from the {@code ViPRDataServicesInitialNodes} client property. Each update sends a signed
 * {@code GET /?endpoint} request (HMAC-SHA1, {@code Authorization: AWS user:signature}) to one of the currently known
 * nodes and replaces the list with the hosts returned in the response. Every discovered host is given the port of the
 * first configured node, because the endpoint response does not carry ports.
 * <p>
 * If no node answers, or the answer lists no hosts, the previous list is kept.
 */
public class ViPRDataServicesServerList extends AbstractServerList<Server> {
    private static final Logger logger = LoggerFactory.getLogger(ViPRDataServicesServerList.class);

    /** Resource requested from a node to enumerate the data nodes; also the path part of the signed string. */
    private static final String ENDPOINT_PATH = "/?endpoint";

    /** Formats the Date header (RFC 822, GMT); SimpleDateFormat is not thread-safe, so every use synchronizes on it. */
    protected final SimpleDateFormat rfc822DateFormat = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z",
            Locale.US);
    /** Unmarshals {@link ListDataNode} response bodies. */
    protected Unmarshaller unmarshaller;

    private String protocol;
    // volatile: setNodeList() writes under the instance lock but getNodeList() reads without it, so the field itself
    // must guarantee that readers see the most recently published list
    private volatile List<Server> nodeList;
    private int port;
    private String user;
    private String secret;
    private HttpClient httpClient;
    // rotates the first node tried on each poll; may wrap to a negative value, which pollForServers() tolerates
    private int requestCounter = 0;

    /**
     * Creates the JAXB unmarshaller and a pooled HTTP client (at most 10 connections per route). Protocol, nodes,
     * credentials and timeouts are supplied later by {@link #initWithNiwsConfig(IClientConfig)}.
     *
     * @throws RuntimeException if the JAXB context for {@link ListDataNode} cannot be created
     */
    public ViPRDataServicesServerList() {
        rfc822DateFormat.setTimeZone(new SimpleTimeZone(0, "GMT"));
        try {
            unmarshaller = JAXBContext.newInstance(ListDataNode.class).createUnmarshaller();
        } catch (JAXBException e) {
            throw new RuntimeException("can't create unmarshaller", e);
        }
        PoolingClientConnectionManager cm = new PoolingClientConnectionManager();
        cm.setDefaultMaxPerRoute(10);
        httpClient = new DefaultHttpClient(cm);
    }

    /**
     * Reads the ViPR settings from the Ribbon client configuration and applies the timeout to the HTTP client.
     *
     * @param clientConfig the client configuration; supplies the protocol ({@code http} or {@code https}, any case),
     *                     the initial nodes, the user, the secret and the timeout (applied as both connection and
     *                     socket timeout; defaults to {@code SmartClientConfig.DEFAULT_TIMEOUT})
     * @throws IllegalArgumentException if the protocol is missing or is not http/https
     * @throws IllegalStateException    if no initial nodes are configured
     */
    @Override
    public void initWithNiwsConfig(IClientConfig clientConfig) {
        // Locale.US keeps the protocol check independent of the default locale
        protocol = clientConfig.getPropertyAsString(SmartClientConfigKey.ViPRDataServicesProtocol, "")
                .toLowerCase(Locale.US);
        if (!Arrays.asList("http", "https").contains(protocol)) {
            throw new IllegalArgumentException("Invalid protocol: " + protocol);
        }

        String nodeStr = clientConfig.getPropertyAsString(SmartClientConfigKey.ViPRDataServicesInitialNodes, "");
        if (nodeStr.trim().length() == 0) {
            throw new IllegalStateException("No servers configured in smartConfig or NIWS config");
        }
        setNodeList(SmartClientConfig.parseServerList(nodeStr));

        // pull the port from the initial node list (it will not be returned by the list-data-nodes call)
        port = getNodeList().get(0).getPort();

        user = clientConfig.getPropertyAsString(SmartClientConfigKey.ViPRDataServicesUser, null);
        secret = clientConfig.getPropertyAsString(SmartClientConfigKey.ViPRDataServicesUserSecret, null);

        int timeout = clientConfig.getPropertyAsInteger(SmartClientConfigKey.ViPRDataServicesTimeout,
                SmartClientConfig.DEFAULT_TIMEOUT);
        HttpConnectionParams.setConnectionTimeout(httpClient.getParams(), timeout);
        HttpConnectionParams.setSoTimeout(httpClient.getParams(), timeout);

        if (logger.isDebugEnabled()) {
            logger.debug("Configured node enumeration");
            logger.debug("--- protocol: " + protocol);
            logger.debug("--- nodeList: " + getNodeList());
            logger.debug("--- user: " + user);
            // never write the secret itself to the log; only report whether one is configured
            logger.debug("--- secret: " + (secret == null ? "null" : "********"));
            logger.debug("--- timeout: " + timeout);
            logger.debug("--- httpClient: " + (httpClient != null));
        }
    }

    /** @return the node list configured by {@link #initWithNiwsConfig(IClientConfig)} (or the latest polled list) */
    @Override
    public List<Server> getInitialListOfServers() {
        return getNodeList();
    }

    /** @return the node list after polling the cluster; see {@link #pollForServers()} */
    @Override
    public List<Server> getUpdatedListOfServers() {
        return pollForServers();
    }

    /**
     * Asks the known nodes for the current set of data nodes and, on success, replaces the node list with them.
     * <p>
     * Nodes are tried one at a time until one answers. The first node tried rotates between calls so the queries are
     * spread across the cluster, and each known node is tried at most once per call. Failures are logged rather than
     * thrown. The list is left unchanged if every node fails or the answer contains no hosts; replacing it with an
     * empty list would leave no node to poll next time.
     *
     * @return the node list after this poll
     */
    protected List<Server> pollForServers() {
        try {
            List<Server> activeNodeList = getNodeList();
            int activeNodeCount = activeNodeList.size();
            List<String> hosts = null;

            // try a different node on each failure until every active node has been tried, but don't always start
            // with the same node (HttpClient may also retry some I/O failures on its own)
            for (int i = 0; i < activeNodeCount; i++) {
                // the double modulus keeps the index non-negative if requestCounter has wrapped around
                Server server = activeNodeList
                        .get((requestCounter++ % activeNodeCount + activeNodeCount) % activeNodeCount);
                logger.debug("endpoint query attempt #" + (i + 1) + ": trying " + server);
                try {
                    hosts = queryDataNodes(server);
                    break;
                } catch (Exception e) {
                    logger.warn("error polling for endpoints on " + server, e);
                }
            }

            if (hosts == null) {
                throw new RuntimeException("Exhausted all nodes; no response available");
            }

            if (hosts.isEmpty()) {
                logger.warn("endpoint query returned no data nodes; keeping current node list: " + activeNodeList);
            } else {
                List<Server> updatedNodeList = new ArrayList<Server>(hosts.size());
                for (String host : hosts) {
                    updatedNodeList.add(new Server(host, port));
                }
                setNodeList(updatedNodeList);
            }
        } catch (Exception e) {
            logger.warn("Unable to poll for servers", e);
        }
        return getNodeList();
    }

    /**
     * Sends one signed endpoint query to {@code server} and returns the hosts listed in its response.
     *
     * @param server the node to query
     * @return the hosts parsed by {@link #parseResponse(HttpResponse)}
     * @throws RuntimeException if the server answers with a status code above 299 (the body is consumed first)
     * @throws Exception        if signing, the HTTP exchange or response parsing fails
     */
    private List<String> queryDataNodes(Server server) throws Exception {
        // NOTE(review): relies on Server.toString() rendering as host:port, as the URL format requires
        HttpGet request = new HttpGet(protocol + "://" + server + ENDPOINT_PATH);

        String rfcDate;
        synchronized (rfc822DateFormat) {
            rfcDate = rfc822DateFormat.format(new Date());
        }

        // the signed string embeds the same date that is sent in the Date header
        String canonicalString = "GET\n\n\n" + rfcDate + "\n" + ENDPOINT_PATH;
        String signature = getSignature(canonicalString, secret);
        request.addHeader("Date", rfcDate);
        request.addHeader("Authorization", "AWS " + user + ":" + signature);

        HttpResponse response = httpClient.execute(request);
        if (response.getStatusLine().getStatusCode() > 299) {
            // consume the body so the pooled connection can be reused before failing this attempt
            EntityUtils.consumeQuietly(response.getEntity());
            throw new RuntimeException("received error response: " + response.getStatusLine());
        }

        logger.debug("received success response: " + response.getStatusLine());
        return parseResponse(response);
    }

    /**
     * Computes the Base64-encoded HMAC-SHA1 of {@code canonicalString} keyed with {@code secret} (both UTF-8 encoded).
     *
     * @param canonicalString the string to sign
     * @param secret          the user's secret key; {@code null} results in a {@link NullPointerException}
     * @return the Base64 signature
     * @throws Exception if HMAC-SHA1 is unavailable or the key is rejected
     */
    protected String getSignature(String canonicalString, String secret) throws Exception {
        Mac mac = Mac.getInstance("HmacSHA1");
        mac.init(new SecretKeySpec(secret.getBytes("UTF-8"), "HmacSHA1"));
        byte[] rawSignature = mac.doFinal(canonicalString.getBytes("UTF-8"));
        // Base64 output is pure ASCII, so decoding it with an explicit charset is exact and platform-independent
        String signature = new String(Base64.encodeBase64(rawSignature), "UTF-8");
        logger.debug("canonicalString:\n" + canonicalString);
        logger.debug("signature:\n" + signature);
        return signature;
    }

    /**
     * Unmarshals the data-node list from an endpoint response, always closing the body stream.
     *
     * @param response a response whose entity holds a {@link ListDataNode} XML document
     * @return the trimmed host names, skipping null or blank entries (may be empty)
     * @throws IOException   if the response has no body or the body cannot be read
     * @throws JAXBException if the body cannot be unmarshalled
     */
    protected List<String> parseResponse(HttpResponse response) throws IOException, JAXBException {
        if (response.getEntity() == null) {
            throw new IOException("response has no body: " + response.getStatusLine());
        }
        InputStream contentStream = response.getEntity().getContent();
        try {
            ListDataNode listDataNode = (ListDataNode) unmarshaller.unmarshal(contentStream);

            List<String> hosts = new ArrayList<String>();
            for (String host : listDataNode.getDataNodes()) {
                if (host == null) {
                    continue;
                }
                String trimmed = host.trim();
                if (trimmed.length() > 0) {
                    hosts.add(trimmed);
                }
            }
            return hosts;
        } finally {
            try {
                contentStream.close();
            } catch (Exception e) {
                // close() reports failures as IOException; swallowing (and logging) here keeps a close failure from
                // replacing an exception already thrown by the try block
                logger.warn("error closing HTTP content stream", e);
            }
        }
    }

    /** @return the current node list (unmodifiable), or {@code null} if no list has been set yet */
    protected List<Server> getNodeList() {
        return nodeList;
    }

    /**
     * Replaces the node list with an unmodifiable snapshot of {@code nodeList}, so later changes to the caller's list
     * cannot alter the published one.
     *
     * @param nodeList the new nodes
     */
    protected synchronized void setNodeList(List<Server> nodeList) {
        this.nodeList = Collections.unmodifiableList(new ArrayList<Server>(nodeList));
    }
}