org.apache.hadoop.mapreduce.security.SecureShuffleUtils.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.hadoop.mapreduce.security.SecureShuffleUtils.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.mapreduce.security;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.io.UnsupportedEncodingException;
import java.net.URL;
import java.security.MessageDigest;
import javax.crypto.SecretKey;
import javax.servlet.http.HttpServletRequest;

import org.apache.commons.codec.binary.Base64;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.mapreduce.security.token.JobTokenSecretManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Charsets;

/**
 * Utilities for generating keys and hashes, and for verifying them, for the
 * shuffle phase.
 */
@InterfaceAudience.Private
@InterfaceStability.Unstable
public class SecureShuffleUtils {
    private static final Logger LOG = LoggerFactory.getLogger(SecureShuffleUtils.class);

    public static final String HTTP_HEADER_URL_HASH = "UrlHash";
    public static final String HTTP_HEADER_REPLY_URL_HASH = "ReplyHash";

    /**
     * Computes the hash of {@code msg} under {@code key} and returns it Base64 encoded.
     *
     * @param msg bytes to hash
     * @param key secret key the hash is computed with
     * @return the Base64 encoding of the hash, as a UTF-8 string
     */
    public static String generateHash(byte[] msg, SecretKey key) {
        return new String(Base64.encodeBase64(generateByteHash(msg, key)), Charsets.UTF_8);
    }

    /**
     * Computes the raw hash of {@code msg} by delegating to
     * {@code JobTokenSecretManager.computeHash}.
     *
     * @param msg bytes to hash
     * @param key secret key the hash is computed with
     * @return the raw (unencoded) hash bytes
     */
    private static byte[] generateByteHash(byte[] msg, SecretKey key) {
        return JobTokenSecretManager.computeHash(msg, key);
    }

    /**
     * Checks whether {@code hash} equals the hash of {@code msg} under {@code key}.
     *
     * @param hash candidate hash, typically received from a remote peer
     * @param msg  message the candidate hash claims to authenticate
     * @param key  secret key used to recompute the expected hash
     * @return {@code true} if the recomputed hash and {@code hash} are identical
     */
    private static boolean verifyHash(byte[] hash, byte[] msg, SecretKey key) {
        byte[] expected = generateByteHash(msg, key);
        // MessageDigest.isEqual examines every byte of equal-length inputs, so the time
        // taken to reject a forged hash does not reveal how many leading bytes matched
        // (an early-exit byte comparison would).
        return MessageDigest.isEqual(expected, hash);
    }

    /**
     * Computes the Base64-encoded hash of a string, using its UTF-8 bytes.
     *
     * @param enc_str string to hash
     * @param key     secret key the hash is computed with
     * @return the Base64-encoded hash
     * @throws IOException declared for API compatibility; not thrown by this implementation
     */
    public static String hashFromString(String enc_str, SecretKey key) throws IOException {
        return generateHash(enc_str.getBytes(Charsets.UTF_8), key);
    }

    /**
     * Verifies that {@code base64Hash} is the Base64-encoded hash of {@code msg}
     * under {@code key}.
     *
     * @param base64Hash Base64-encoded hash to check; {@code null} is treated as a failed verification
     * @param msg        message the hash should authenticate
     * @param key        secret key used to recompute the expected hash
     * @throws IOException if the hash is missing or does not match
     */
    public static void verifyReply(String base64Hash, String msg, SecretKey key) throws IOException {
        // A missing hash is a verification failure, reported via the declared exception
        // type rather than an unchecked NullPointerException.
        if (base64Hash == null) {
            throw new IOException("Verification of the hashReply failed");
        }
        byte[] hash = Base64.decodeBase64(base64Hash.getBytes(Charsets.UTF_8));

        if (!verifyHash(hash, msg.getBytes(Charsets.UTF_8), key)) {
            throw new IOException("Verification of the hashReply failed");
        }
    }

    /**
     * Shuffle-specific helper: builds the string to be hashed from a URL.
     *
     * @param url URL whose port, path and query form the message
     * @return string for encoding, in the form {@code <port><path>?<query>}
     */
    public static String buildMsgFrom(URL url) {
        return buildMsgFrom(url.getPath(), url.getQuery(), url.getPort());
    }

    /**
     * Shuffle-specific helper: builds the string to be hashed from an incoming request.
     *
     * @param request request whose URI, query string and local port form the message
     * @return string for encoding, in the form {@code <port><path>?<query>}
     */
    public static String buildMsgFrom(HttpServletRequest request) {
        return buildMsgFrom(request.getRequestURI(), request.getQueryString(), request.getLocalPort());
    }

    /**
     * Shuffle-specific helper: concatenates port, path and query into the string to be hashed.
     *
     * <p>Both public overloads delegate here, so the URL-based and request-based forms
     * produce the same layout. A {@code null} query is rendered as the literal text
     * {@code "null"} by string concatenation. That behaviour is kept unchanged, because
     * whichever side checks the hash must rebuild exactly the same string.
     *
     * @param uri_path  path component
     * @param uri_query query component (may be {@code null})
     * @param port      port number
     * @return string for encoding
     */
    private static String buildMsgFrom(String uri_path, String uri_query, int port) {
        return String.valueOf(port) + uri_path + "?" + uri_query;
    }

    /**
     * Converts a byte array to a lower-case hexadecimal string, two digits per byte.
     *
     * <p>Each byte is zero-padded to two digits, so distinct inputs always yield distinct
     * strings (an unpadded encoding would map {0x01, 0x23} and {0x12, 0x03} to the same text).
     *
     * @param ba bytes to convert
     * @return string with the HEX value of the bytes
     */
    public static String toHex(byte[] ba) {
        StringBuilder hex = new StringBuilder(ba.length * 2);
        for (byte b : ba) {
            hex.append(Character.forDigit((b >> 4) & 0xF, 16));
            hex.append(Character.forDigit(b & 0xF, 16));
        }
        return hex.toString();
    }
}