org.kaaproject.kaa.server.transports.http.transport.HttpTestClient.java Source code

Java tutorial

Introduction

Here is the source code for org.kaaproject.kaa.server.transports.http.transport.HttpTestClient.java

Source

/*
 * Copyright 2014-2016 CyberVision, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.kaaproject.kaa.server.transports.http.transport;

import org.apache.avro.specific.SpecificRecordBase;
import org.apache.commons.codec.binary.Base64;
import org.kaaproject.kaa.common.avro.AvroByteArrayConverter;
import org.kaaproject.kaa.common.endpoint.CommonEpConstans;
import org.kaaproject.kaa.common.endpoint.security.MessageEncoderDecoder;
import org.kaaproject.kaa.common.hash.EndpointObjectHash;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Vector;

/**
 * Abstract HTTP Test Client Class.
 *
 * @author Andrey Panasenko <apanasenko@cybervisiontech.com>
 */
abstract public class HttpTestClient<T extends SpecificRecordBase, R extends SpecificRecordBase>
        implements Runnable {

    /**
     * The Constant logger.
     */
    protected static final Logger logger = LoggerFactory.getLogger(HttpTestClient.class);
    /**
     * Random generator
     */
    protected static Random rnd = new Random();
    /**
     * Destination URL connection
     */
    private HttpURLConnection connection;
    /**
     * Multipart objects container
     */
    private MultipartObjects objects;

    /**
     * Test ID, random generated
     */
    private int testId;

    /**
     * byte array for signature
     */
    private byte[] signature;

    /**
     * byte array for encrypted SessionKey
     */
    private byte[] key;

    /**
     * byte array for POST Data
     */
    private byte[] data;

    /**
     * encoder/decoder
     */
    private MessageEncoderDecoder crypt;

    /**
     * Client Private Key
     */
    private PrivateKey clientPrivateKey;

    /**
     * Client Public Key
     */
    private PublicKey clientPublicKey;

    /**
     * Client Public Key Hash
     */
    private EndpointObjectHash clientPublicKeyHash;

    /**
     * AVRO request converter
     */
    private AvroByteArrayConverter<T> requestConverter;

    /**
     * AVRO response converter
     */
    private AvroByteArrayConverter<R> responseConverter;

    /**
     * Activity interface
     */
    private HttpActivity<R> activity;

    /**
     * generated test SyncRequest
     */
    private T request;

    /**
     * Constructor.
     *
     * @param serverPublicKey - server public key
     * @param commandName     - command name, used as end of URL
     * @param activity        - Activity interface implementation.
     * @throws MalformedURLException - throws if URL is incorrect
     * @throws Exception             - throws if request creation failed
     */
    public HttpTestClient(PublicKey serverPublicKey, String commandName, HttpActivity<R> activity)
            throws MalformedURLException, Exception {
        testId = rnd.nextInt();
        this.activity = activity;
        //TODO: replace
        int bindPort = 7888;
        String url = "http://localhost:" + bindPort + "/domain/" + commandName;
        connection = (HttpURLConnection) new URL(url).openConnection();
        objects = new MultipartObjects();
        requestConverter = new AvroByteArrayConverter<>(getRequestConverterClass());
        responseConverter = new AvroByteArrayConverter<>(getResponseConverterClass());
        init(serverPublicKey);
    }

    /**
     * Generate String with random ascii symbols from 48 till 122 with length size.
     *
     * @param size of String
     * @return String with random ascii symbols
     */
    public static String getRandomString(int size) {
        return MultipartObjects.getRandomString(size);
    }

    /**
     * generate random bytes array with size
     *
     * @param size of bytes
     * @return byte[] array of random bytes
     */
    public static byte[] getRandomBytes(int size) {
        byte[] rndbytes = new byte[size];
        rnd.nextBytes(rndbytes);
        return rndbytes;
    }

    /**
     * Initialization of request keys and encoder/decoder
     *
     * @param serverPublicKey - server public key
     * @throws Exception - if key generation failed.
     */
    private void init(PublicKey serverPublicKey) throws Exception {
        KeyPairGenerator clientKeyGen;
        try {
            clientKeyGen = KeyPairGenerator.getInstance("RSA");
            clientKeyGen.initialize(2048);
            KeyPair clientKeyPair = clientKeyGen.genKeyPair();
            clientPrivateKey = clientKeyPair.getPrivate();
            clientPublicKey = clientKeyPair.getPublic();
        } catch (NoSuchAlgorithmException e) {
            throw new Exception(e.toString());
        }
        crypt = new MessageEncoderDecoder(clientPrivateKey, clientPublicKey, serverPublicKey);
        try {
            key = crypt.getEncodedSessionKey();
        } catch (GeneralSecurityException e) {
            throw new Exception(e.toString());
        }

        ByteBuffer publicKeyBuffer = ByteBuffer
                .wrap(EndpointObjectHash.fromSha1(clientPublicKey.getEncoded()).getData());

        clientPublicKeyHash = EndpointObjectHash.fromBytes(publicKeyBuffer.array());

    }

    /**
     * Post initialization, encrypt and sign request
     *
     * @param request - request to encrypt and sign
     * @throws Exception - in case of encrypt error
     */
    protected void postInit(T request) throws Exception {

        try {
            byte[] requestBodyRaw = requestConverter.toByteArray(request);
            data = crypt.encodeData(requestBodyRaw);
            signature = crypt.sign(data);
            if (signature.length > 256) {
                throw new Exception("Error signature length must not be more than 256, but " + signature.length);
            }
        } catch (IOException | GeneralSecurityException e) {
            throw new Exception(e.toString());
        }

        objects.addObject(CommonEpConstans.REQUEST_SIGNATURE_ATTR_NAME, signature);
        objects.addObject(CommonEpConstans.REQUEST_KEY_ATTR_NAME, key);
        objects.addObject(CommonEpConstans.REQUEST_DATA_ATTR_NAME, data);
    }

    /* (non-Javadoc)
     * @see java.lang.Runnable#run()
     */
    @Override
    public void run() {
        logger.trace("Test: " + testId + " started...");
        IOException error = null;
        try {
            //connection.setChunkedStreamingMode(2048);
            connection.setRequestMethod("POST");
            connection.setDoOutput(true);
            connection.setRequestProperty("Content-Type", objects.getContentType());

            DataOutputStream out = new DataOutputStream(connection.getOutputStream());
            objects.dumbObjects(out);
            out.flush();
            out.close();
        } catch (IOException e) {
            e.printStackTrace();
            error = e;
        }
        List<Byte> bodyArray = new Vector<>();

        try {
            DataInputStream r = new DataInputStream(connection.getInputStream());
            while (true) {
                bodyArray.add(new Byte(r.readByte()));
            }
        } catch (EOFException eof) {

        } catch (IOException e) {
            e.printStackTrace();
            error = e;
        }
        byte[] body = new byte[bodyArray.size()];
        for (int i = 0; i < body.length; i++) {
            body[i] = bodyArray.get(i);
        }
        processComplete(error, connection.getHeaderFields(), body);
    }

    /**
     * push Response to client invocation code
     *
     * @param e      - set if error received during HTTP request processing
     * @param header - header list
     * @param body   - body byte array
     */
    private void processComplete(IOException e, Map<String, List<String>> header, byte[] body) {
        if (e != null) {
            e.printStackTrace();
            activity.httpRequestComplete(e, this.testId, null);
            return;
        }
        try {
            R response = decodeHttpResponse(header, body);
            activity.httpRequestComplete(null, this.testId, response);
        } catch (Exception e1) {
            e1.printStackTrace();
            activity.httpRequestComplete(e1, this.testId, null);
        }

    }

    /**
     * Decode http response to Response
     *
     * @return type R Response
     */
    protected R decodeHttpResponse(Map<String, List<String>> header, byte[] body) throws Exception {
        if (header.containsKey(CommonEpConstans.SIGNATURE_HEADER_NAME)
                && header.get(CommonEpConstans.SIGNATURE_HEADER_NAME) != null
                && header.get(CommonEpConstans.SIGNATURE_HEADER_NAME).size() > 0) {
            String sigHeader = header.get(CommonEpConstans.SIGNATURE_HEADER_NAME).get(0);
            byte[] respSignature = Base64.decodeBase64(sigHeader);
            byte[] respData = body;
            crypt.verify(respData, respSignature);
            logger.trace("Test " + getId() + " response verified, body size " + body.length);
            byte[] respDecoded = crypt.decodeData(respData);
            return responseConverter.fromByteArray(respDecoded);
        } else {
            throw new Exception(
                    "HTTP response incorrect, no signature fields " + CommonEpConstans.SIGNATURE_HEADER_NAME);
        }
    }

    /**
     * Test ID getter.
     *
     * @return int Test ID
     */
    public int getId() {
        return testId;
    }

    /**
     * Client Public Key getter.
     *
     * @return the clientPublicKey
     */
    public PublicKey getClientPublicKey() {
        return clientPublicKey;
    }

    /**
     * Client Public Key Hash getter.
     *
     * @return the clientPublicKeyHash
     */
    public EndpointObjectHash getClientPublicKeyHash() {
        return clientPublicKeyHash;
    }

    /**
     * @return the request
     */
    public T getRequest() {
        return request;
    }

    /**
     *
     * @param request
     */
    public void setRequest(T request) {
        this.request = request;
    }

    /**
     * Gets the request converter class.
     *
     * @return the request converter class
     */
    protected abstract Class<T> getRequestConverterClass();

    /**
     * Gets the response converter class.
     *
     * @return the response converter class
     */
    protected abstract Class<R> getResponseConverterClass();
}