net.java.jless.tls.Record.java Source code

Java tutorial

Introduction

Here is the source code for net.java.jless.tls.Record.java

Source

/*
 * Copyright 2001-2011 Joel Hockey (joel.hockey@gmail.com).  All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

package net.java.jless.tls;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.net.SocketException;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;

import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Formats a TLS Record
 *   @author      Joel Hockey
 */
public class Record {
    private static Log log;

    static {
        try {
            // use commons logging if available, else no logging
            log = LogFactory.getLog(Record.class);
        } catch (Throwable t) {
        } // ignore
    }

    // Class (static) variables.
    public static final byte ALERT_CLOSE_NOTIFY = 0;
    public static final byte ALERT_WARNING = 1;
    public static final byte ALERT_FATAL = 2;

    public static final byte CONTENTTYPE_CHANGE_CIPHER_SPEC = 20;
    public static final byte CONTENTTYPE_ALERT = 21;
    public static final byte CONTENTTYPE_HANDSHAKE = 22;
    public static final byte CONTENTTYPE_APPLICATION_DATA = 23;

    private static final int MAX_FRAGMENT_LENGTH = 491; // I don't know why?

    // state of reading and writing
    private boolean clientWriteCipherIsNull = true;
    private long clientWriteSeqNum = 0;
    private boolean serverWriteCipherIsNull = true;
    private long serverWriteSeqNum = 0;

    // current state of session
    private TLSSocket tls;
    private HMAC hmacClientWrite;
    private HMAC hmacServerWrite;
    private Cipher encryptCipher;
    private Cipher decryptCipher;
    public OutputStream outs;
    private InputStream ins;
    private int macSize;
    private int blockSize;

    // 18437 is max size of TLS record
    // always try to read an extra 5 bytes to determine if
    // another record fragment is ready
    private byte[] readBuf = new byte[18442];
    private int readBufOffset;

    /**
     * Class constructor.
     */
    public Record(TLSSocket tls) {
        this.tls = tls;
    }

    /**
     * Send a message to the server. One or more records will be written
     * depending on the size of the message.
     *
     * @param contentType   The content type of the message.  Must be valid.
     * @param msg   The message(s) to send
     */
    public void sendMessage(byte contentType, byte[] msg) throws TLSException {
        try {
            // the fragment of the message that gets written each time.
            byte[] fragment = null;

            int msgBytesSent = 0;
            int msgBytesToSend = 0;
            int msgBytesRemaining = msg.length;

            byte[] length = { 0, 0 };

            // record lengths must be less than MAX_FRAGMENT_LENGTH.
            // We may have to send mulitiple records
            for (;;) {
                if (msgBytesRemaining == 0) {
                    break;
                }

                msgBytesToSend = msgBytesRemaining > MAX_FRAGMENT_LENGTH ? MAX_FRAGMENT_LENGTH : msgBytesRemaining;

                // encrypt if required
                if (!clientWriteCipherIsNull) {
                    byte[] seqNum = l2ba(clientWriteSeqNum++);
                    byte[] mac = getMAC(hmacClientWrite, seqNum, contentType, msg, msgBytesSent, msgBytesToSend);
                    int paddingLen = blockSize == 0 ? 0 : blockSize - ((msgBytesToSend + mac.length) % blockSize);
                    byte[] messageMacPad = new byte[msgBytesToSend + mac.length + paddingLen];
                    System.arraycopy(msg, msgBytesSent, messageMacPad, 0, msgBytesToSend);
                    System.arraycopy(mac, 0, messageMacPad, msgBytesToSend, mac.length);
                    // put padding
                    for (int i = 0; i < paddingLen; i++) {
                        messageMacPad[messageMacPad.length - 1 - i] = (byte) (paddingLen - 1);
                    }
                    try {
                        log("encrypt input: " + Hex.b2s(messageMacPad));
                        encryptCipher.update(messageMacPad, 0, messageMacPad.length, messageMacPad);
                    } catch (Exception e) {
                        throw new TLSException("encrypt error: " + e.getMessage());
                    }
                    fragment = messageMacPad;

                } else {
                    fragment = new byte[msgBytesToSend];
                    System.arraycopy(msg, msgBytesSent, fragment, 0, msgBytesToSend);
                }
                length[0] = (byte) (fragment.length >> 8);
                length[1] = (byte) fragment.length;

                /* send everything in correct order */
                outs.write(contentType);
                outs.write(TLSSocket.PROTOCOL_VERSION);
                outs.write(length);
                outs.write(fragment);
                outs.flush();

                msgBytesSent += msgBytesToSend;
                msgBytesRemaining -= msgBytesToSend;
            }
        } catch (Exception e) {
            e.printStackTrace();
            throw new TLSException("Error sending Record: " + e.getMessage());
        }
    }

    /**
     * Returns the fragment contained in a single Record.
     *
     * The fragment is not necessarily a message.  It may be only part of a
     * message or may be multiple messages.  In most implementations though,
     * it will be one single message.  Returns null if the connection is
     * closed or an error occurs.
     *
     * @return  fragment
     */
    public byte[] readRecord() throws TLSException {
        int recordLength = 0;
        try {
            // read header if required
            while (readBufOffset < 5) { // header is 5 bytes
                int len = ins.read(readBuf, readBufOffset, 5 - readBufOffset);
                if (len == -1) { // no more data to read
                    tls.setConnected(false);
                    return null;
                }
                readBufOffset += len;
            }

            // check ProtocolVersion
            if (readBuf[1] != TLSSocket.PROTOCOL_VERSION[0] || readBuf[2] != TLSSocket.PROTOCOL_VERSION[1]) {
                log("Bad Protocol Version in Record Header 0x" + Hex.b2s(readBuf, 0, 5));
                throw new TLSException("Bad Protocol Version in Record Header 0x" + Hex.b2s(readBuf, 0, 5));
            }

            // get the length
            recordLength = (readBuf[3] & 0xFF) << 8 | (readBuf[4] & 0xFF);

            // read the rest
            while (readBufOffset < recordLength + 5) {
                // try to read an extra 5 bytes here to see if more fragments ready
                int len = ins.read(readBuf, readBufOffset, recordLength + 10 - readBufOffset);
                if (len < 0) {
                    throw new TLSException("Bad Record Received");
                }
                readBufOffset += len;
            }
            log("record read: (" + recordLength + ") " + Hex.b2s(readBuf, 0, 5 + recordLength));
        } catch (TLSException tlse) {
            throw tlse;
        } catch (SocketException e) { // connection closed
            e.printStackTrace();
            tls.setConnected(false);
            return null;
        } catch (IOException e) {
            e.printStackTrace();
            throw new TLSException("Error in Record.readRecord()");
        }

        byte[] fragment = new byte[recordLength];

        // decrypt if !ServerWriteCipherIsNull
        if (!serverWriteCipherIsNull) {
            try {
                decryptCipher.update(readBuf, 5, recordLength, fragment);
            } catch (Exception e) {
            }

            int fragmentLength = recordLength - macSize;
            // subtract padding from fragmentLength
            if (blockSize > 0) {
                fragmentLength -= ((fragment[recordLength - 1] & 0xff) + 1);
            }

            byte[] seqNum = l2ba(serverWriteSeqNum++);
            byte[] mac = getMAC(hmacServerWrite, seqNum, readBuf[0], fragment, 0, fragmentLength);
            log("expected mac: " + Hex.b2s(mac));

            for (int i = 0; i < mac.length; i++) {
                if (fragment[fragmentLength + i] != mac[i]) {
                    log("Bad MAC received:\ndecrypted fragment with pad: " + Hex.b2s(fragment, 0, recordLength));
                    throw new TLSException(
                            "Bad MAC received: decrypted fragment: " + Hex.b2s(fragment, 0, recordLength));
                }
            }

            log("mac good!");
            byte[] fragmentNoMac = new byte[fragmentLength];
            System.arraycopy(fragment, 0, fragmentNoMac, 0, fragmentLength);
            fragment = fragmentNoMac;

        } else {
            System.arraycopy(readBuf, 5, fragment, 0, recordLength);
        }

        // check ContentType
        if (readBuf[0] == CONTENTTYPE_ALERT) {
            if (fragment.length != 2) {
                throw new TLSException("Badly formed Alert message received");
            }

            if (fragment[1] != ALERT_CLOSE_NOTIFY) {
                throw new TLSException("Unsupported Alert received : 0x" + Hex.b2s(fragment));
            }

            sendMessage(CONTENTTYPE_ALERT, new byte[] { ALERT_WARNING, ALERT_CLOSE_NOTIFY });
            tls.setConnected(false);
            return null;
        }

        // reset m_readBufOffset
        // copy any extra data (like the next 5 bytes of a record header) to front of m_readBuf
        if (readBufOffset > recordLength + 5) {
            System.arraycopy(readBuf, recordLength + 5, readBuf, 0, readBufOffset - (recordLength + 5));
        }
        readBufOffset -= (recordLength + 5);

        return fragment;
    }

    /**
     * Promote the pending write state to be the current state
     */
    public void changeClientWriteState() {
        clientWriteCipherIsNull = false;
    }

    /**
     * Promote the pending read state to be the current state
     */
    public void changeServerWriteState() {
        serverWriteCipherIsNull = false;
    }

    /**
     * Sets the key block for the pending state.
     *
     * @param keyBlock  enough material to set all keys
     */
    public void setKeyBlock(int cipherSuite, byte[] keyBlock) {
        try {
            // assume TLS_RSA_WITH_RC4_128_MD5
            macSize = 16;
            blockSize = 0;
            int keySize = 16;
            int ivSize = 0;
            String keyAlg = "RC4";
            String cipherAlg = "RC4";
            String macAlg = "MD5";

            if (cipherSuite == TLSSocket.TLS_RSA_WITH_3DES_EDE_CBC_SHA) {
                cipherAlg = "DESede/CBC/NoPadding";
                keyAlg = "DESede";
                macAlg = "SHA-1";
                macSize = 20;
                blockSize = 8;
                keySize = 24;
                ivSize = 8;
            } else if (cipherSuite == TLSSocket.TLS_RSA_WITH_AES_128_CBC_SHA) {
                cipherAlg = "AES/CBC/NoPadding";
                keyAlg = "AES";
                macAlg = "SHA-1";
                macSize = 20;
                blockSize = 16;
                keySize = 16;
                ivSize = 16;
            }

            byte[] clientWriteMACSecret = sub(keyBlock, 0, macSize);
            byte[] serverWriteMACSecret = sub(keyBlock, macSize, macSize);
            byte[] clientWriteKey = sub(keyBlock, 2 * macSize, keySize);
            byte[] serverWriteKey = sub(keyBlock, 2 * macSize + keySize, keySize);
            byte[] clientWriteIV = sub(keyBlock, 2 * (macSize + keySize), ivSize);
            byte[] serverWriteIV = sub(keyBlock, 2 * (macSize + keySize) + ivSize, ivSize);

            hmacClientWrite = new HMAC(MessageDigest.getInstance(macAlg), clientWriteMACSecret);
            hmacServerWrite = new HMAC(MessageDigest.getInstance(macAlg), serverWriteMACSecret);

            encryptCipher = Cipher.getInstance(cipherAlg);
            decryptCipher = Cipher.getInstance(cipherAlg);

            // no IV for RC4
            if (cipherSuite == TLSSocket.TLS_RSA_WITH_RC4_128_MD5) {
                encryptCipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(clientWriteKey, keyAlg));
                decryptCipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(serverWriteKey, keyAlg));
            } else {
                encryptCipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(clientWriteKey, keyAlg),
                        new IvParameterSpec(clientWriteIV));
                decryptCipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(serverWriteKey, keyAlg),
                        new IvParameterSpec(serverWriteIV));
            }
            log("client write key: " + Hex.b2s(clientWriteKey));
            log("client write iv : " + Hex.b2s(clientWriteIV));
            log("server write key: " + Hex.b2s(serverWriteKey));
            log("server write iv : " + Hex.b2s(serverWriteIV));

        } catch (GeneralSecurityException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Set the connect cipher states to null and writeSeqNums to zero.  Create
     * new socket connection to server
     */
    public void setSocket(Socket s) throws TLSException {
        try {
            clientWriteCipherIsNull = true;
            serverWriteCipherIsNull = true;
            clientWriteSeqNum = 0;
            serverWriteSeqNum = 0;
            outs = new BufferedOutputStream(s.getOutputStream());
            ins = new BufferedInputStream(s.getInputStream());
        } catch (Exception e) {
            e.printStackTrace();
            throw new TLSException("Error in Record.reset(): " + e.getMessage());
        }
    }

    /**
     * Return the value of the available method on the underlying socket
     * connection.
     */
    public boolean available() {
        return readBufOffset > 0;
    }

    public void close() throws IOException {
        sendMessage(CONTENTTYPE_ALERT, new byte[] { ALERT_WARNING, ALERT_CLOSE_NOTIFY });
        outs.close();
    }

    /**
     * Return the MAC of the given byte array using the protocol specified
     * in TLSSocket
     *
     * @param end Either SecurityParameters.CONNECTIONEND_CLIENT or
     *      SecurityParameters.CONNECTIONEND_SERVER.  Indicates whether to use
     *      ServerWrite keys or ClientWrite keys for calculating MAC
     * @param type      content type of message
     * @param buf       the byte array containing the message to get the MAC of.
     * @param offset    where the message starts
     * @param length    the length of the message
     *
     * @return the MAC
     */
    private byte[] getMAC(HMAC hmac, byte[] seqNum, byte type, byte[] message, int offset, int length) {

        // concatenate all values to be MACed,
        // seqNum (8) + ContentType (1) + version (2) + message vector (2 len, msg.length)
        byte[] input = new byte[13 + length];
        System.arraycopy(seqNum, 0, input, 0, 8);
        input[8] = type;
        System.arraycopy(TLSSocket.PROTOCOL_VERSION, 0, input, 9, 2);
        input[11] = (byte) (length >> 8);
        input[12] = (byte) (length);
        System.arraycopy(message, offset, input, 13, length);

        log("input to mac: " + Hex.b2s(input));

        // MAC them
        return hmac.digest(input);
    }

    // Converts long to byte array
    private byte[] l2ba(long l) {
        byte[] byteVal = new byte[8];
        byteVal[7] = (byte) (l);
        byteVal[6] = (byte) (l >> 8);
        byteVal[5] = (byte) (l >> 16);
        byteVal[4] = (byte) (l >> 24);
        byteVal[3] = (byte) (l >> 32);
        byteVal[2] = (byte) (l >> 40);
        byteVal[1] = (byte) (l >> 48);
        byteVal[0] = (byte) (l >> 56);
        return byteVal;
    }

    private static byte[] sub(byte[] buf, int offset, int len) {
        byte[] result = new byte[len];
        System.arraycopy(buf, offset, result, 0, len);
        return result;
    }

    static void log(String msg) {
        if (log != null) {
            log.debug(msg);
        }
    }
}