numpy.core.NDArrayUtil.java Source code

Java tutorial

Introduction

Here is the source code for numpy.core.NDArrayUtil.java

Source

/*
 * Copyright (c) 2015 Villu Ruusmann
 *
 * This file is part of JPMML-SkLearn
 *
 * JPMML-SkLearn is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-SkLearn is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-SkLearn.  If not, see <http://www.gnu.org/licenses/>.
 */
package numpy.core;

import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import com.google.common.io.ByteStreams;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import com.google.common.primitives.UnsignedInts;
import net.razorvine.pickle.Unpickler;
import net.razorvine.serpent.Parser;
import net.razorvine.serpent.ast.Ast;
import numpy.DType;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sklearn.TupleUtil;

/**
 * Utility methods for decoding the data payload of NumPy arrays, including standalone NPY files.
 */
public class NDArrayUtil {

    private NDArrayUtil() {
    }

    /**
     * Gets the shape of an array.
     *
     * @return The size of each dimension, converted to Java <code>int</code> values.
     */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    static public int[] getShape(NDArray array) {
        Object[] shape = array.getShape();

        // The shape elements are expected to be Number instances
        List<? extends Number> values = (List) Arrays.asList(shape);

        return Ints.toArray(ValueUtil.asIntegers(values));
    }

    /**
     * Gets the payload of an array that has a simple (ie. non-structured) data type.
     *
     * <p>
     * The payload is returned as a flat list in row-major (ie. C-type) order.
     * Two-dimensional column-major (ie. Fortran-type) arrays are reordered accordingly.
     * </p>
     */
    static public List<?> getContent(NDArray array) {
        Object content = array.getContent();

        return asJavaList(array, (List<?>) content);
    }

    /**
     * Gets the payload of one field of an array that has a structured data type.
     *
     * <p>
     * The payload is returned as a flat list in row-major (ie. C-type) order,
     * in the same way as {@link #getContent(NDArray)}.
     * </p>
     *
     * @param key The field name (ie. the first element of the corresponding descr tuple).
     */
    @SuppressWarnings("unchecked")
    static public List<?> getContent(NDArray array, String key) {
        Map<String, ?> content = (Map<String, ?>) array.getContent();

        return asJavaList(array, (List<?>) content.get(key));
    }

    /**
     * Ensures that a flat array payload is in row-major (ie. C-type) order.
     *
     * @throws IllegalArgumentException If the array is column-major and has more than two dimensions.
     */
    static private <E> List<E> asJavaList(NDArray array, List<E> values) {
        boolean fortranOrder = array.getFortranOrder();

        if (!fortranOrder) {
            return values;
        }

        int[] shape = getShape(array);

        switch (shape.length) {
        // Zero- and one-dimensional arrays are laid out identically in both orders
        case 0:
        case 1:
            return values;
        case 2:
            return toJavaList(values, shape[0], shape[1]);
        default:
            throw new IllegalArgumentException("Expected a 0-, 1- or 2-dimensional Fortran-ordered array, got a " + shape.length + "-dimensional array");
        }
    }

    /**
     * Translates a column-major (ie. Fortran-type) array to a row-major (ie. C-type) array.
     *
     * @param rows The size of the first dimension.
     * @param columns The size of the second dimension.
     */
    static private <E> List<E> toJavaList(List<E> values, int rows, int columns) {
        List<E> result = new ArrayList<>(values.size());

        for (int i = 0; i < values.size(); i++) {
            int row = i / columns;
            int column = i % columns;

            // In column-major order, the element at (row, column) is stored at offset (column * rows + row)
            E value = values.get((column * rows) + row);

            result.add(value);
        }

        return result;
    }

    /**
     * Parses a standalone NPY file.
     *
     * <p>
     * Format versions 1.0, 2.0 and 3.0 are supported.
     * The remainder of the stream after the header is passed to {@link NDArray} as raw bytes.
     * </p>
     *
     * @throws IOException If the stream is not a supported NPY file.
     *
     * @see <a href="http://docs.scipy.org/doc/numpy-dev/neps/npy-format.html">NPY format (NEP 1)</a>
     * @see <a href="https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html">numpy.lib.format</a>
     */
    static public NDArray parseNpy(InputStream is) throws IOException {
        byte[] magicBytes = new byte[MAGIC_STRING.length];

        ByteStreams.readFully(is, magicBytes);

        if (!Arrays.equals(magicBytes, MAGIC_STRING)) {
            throw new IOException("Expected NPY magic string, got " + Arrays.toString(magicBytes));
        }

        int majorVersion = readUnsignedByte(is);
        int minorVersion = readUnsignedByte(is);

        if (minorVersion != 0) {
            throw new IOException("Unsupported NPY format version " + majorVersion + "." + minorVersion);
        }

        int headerLength;
        String headerEncoding;

        switch (majorVersion) {
        // Version 1.0: 2-byte unsigned header length, ASCII header
        case 1:
            headerLength = readUnsignedShort(is, ByteOrder.LITTLE_ENDIAN);
            headerEncoding = "ISO-8859-1";
            break;
        // Version 2.0: 4-byte unsigned header length, ASCII header
        case 2:
            headerLength = readInt(is, ByteOrder.LITTLE_ENDIAN);
            headerEncoding = "ISO-8859-1";
            break;
        // Version 3.0: 4-byte unsigned header length, UTF-8 header
        case 3:
            headerLength = readInt(is, ByteOrder.LITTLE_ENDIAN);
            headerEncoding = "UTF-8";
            break;
        default:
            throw new IOException("Unsupported NPY format version " + majorVersion + "." + minorVersion);
        }

        // A 4-byte unsigned length above Integer.MAX_VALUE wraps to a negative int, and cannot be buffered
        if (headerLength < 0) {
            throw new IOException("Unsupported NPY header length " + UnsignedInts.toLong(headerLength));
        }

        byte[] headerBytes = new byte[headerLength];

        ByteStreams.readFully(is, headerBytes);

        String header = new String(headerBytes, headerEncoding);

        // Remove trailing whitespace (padding and the terminating newline)
        header = header.trim();

        Map<String, ?> headerDict = parseDict(header);

        Object descr = headerDict.get("descr");
        Boolean fortranOrder = (Boolean) headerDict.get("fortran_order");
        Object[] shape = (Object[]) headerDict.get("shape");

        byte[] data = ByteStreams.toByteArray(is);

        NDArray result = new NDArray();

        result.__setstate__(
                new Object[] { Arrays.asList(majorVersion, minorVersion), shape, descr, fortranOrder, data });

        return result;
    }

    /**
     * Parses the raw data of an array.
     *
     * @param descr A type descriptor string, a {@link DType}, or a list of (name, type descriptor) tuples.
     * @param shape The array shape. The number of elements is the product of all dimensions.
     *
     * @return A list of elements for a simple data type,
     * or a map from field name to list of elements for a structured data type.
     *
     * @throws IllegalArgumentException If the number of elements does not fit in an <code>int</code>.
     */
    @SuppressWarnings("unchecked")
    static public Object parseData(InputStream is, Object descr, Object[] shape) throws IOException {

        if (descr instanceof DType) {
            DType dType = (DType) descr;

            descr = dType.toDescr();
        }

        // Accumulate in a long, so that int overflow is detected rather than silently wrapped
        long count = 1L;

        for (Object dimension : shape) {
            count *= ValueUtil.asInt((Number) dimension);
        } // End for

        int length = Ints.checkedCast(count);

        if (descr instanceof String) {
            return parseArray(is, (String) descr, length);
        }

        // Structured data type
        List<Object[]> fields = (List<Object[]>) descr;

        Map<String, List<?>> result = new LinkedHashMap<>();

        List<Object[]> objects = parseMultiArray(is, (List<String>) TupleUtil.extractElement(fields, 1), length);

        for (int i = 0; i < fields.size(); i++) {
            Object[] field = fields.get(i);

            result.put((String) field[0], TupleUtil.extractElement(objects, i));
        }

        return result;
    }

    /**
     * Reads elements of a simple data type until at least <code>length</code> elements have been collected.
     *
     * <p>
     * For non-object data types, exactly <code>length</code> elements are read.
     * For the object data type, each element is expected to be an {@link NDArray}. Its content is appended in place
     * of the element itself, so the result may hold more than <code>length</code> elements.
     * </p>
     */
    static public List<Object> parseArray(InputStream is, String descr, int length) throws IOException {
        List<Object> result = new ArrayList<>(length);

        TypeDescriptor descriptor = new TypeDescriptor(descr);

        while (result.size() < length) {
            Object element = descriptor.read(is);

            if (descriptor.isObject()) {
                NDArray array = (NDArray) element;

                result.addAll(NDArrayUtil.getContent(array));

                continue;
            }

            result.add(element);
        }

        return result;
    }

    /**
     * Reads <code>length</code> records of a structured data type.
     *
     * @param descrs The type descriptor of each field, in storage order.
     *
     * @return One array of field values per record.
     *
     * @throws IllegalArgumentException If any field has the object data type.
     */
    static public List<Object[]> parseMultiArray(InputStream is, List<String> descrs, int length)
            throws IOException {
        List<Object[]> result = new ArrayList<>(length);

        List<TypeDescriptor> descriptors = new ArrayList<>();

        for (String descr : descrs) {
            TypeDescriptor descriptor = new TypeDescriptor(descr);

            if (descriptor.isObject()) {
                throw new IllegalArgumentException(descr);
            }

            descriptors.add(descriptor);
        }

        for (int i = 0; i < length; i++) {
            Object[] element = new Object[descriptors.size()];

            for (int j = 0; j < descriptors.size(); j++) {
                TypeDescriptor descriptor = descriptors.get(j);

                element[j] = descriptor.read(is);
            }

            result.add(element);
        }

        return result;
    }

    /**
     * Parses a Python dict literal (such as an NPY header) using the Serpent parser.
     */
    @SuppressWarnings("unchecked")
    static private Map<String, ?> parseDict(String string) {
        Parser parser = new Parser();

        Ast ast = parser.parse(string);

        return (Map<String, ?>) ast.getData();
    }

    static private byte readByte(InputStream is) throws IOException {
        int b = is.read();
        if (b < 0) {
            throw new EOFException();
        }

        return (byte) b;
    }

    static private int readUnsignedByte(InputStream is) throws IOException {
        int b = is.read();
        if (b < 0) {
            throw new EOFException();
        }

        return b;
    }

    /**
     * @return A value in the range [0, 65535].
     */
    static private int readUnsignedShort(InputStream is, ByteOrder byteOrder) throws IOException {
        byte b1 = readByte(is);
        byte b2 = readByte(is);

        if ((ByteOrder.BIG_ENDIAN).equals(byteOrder)) {
            return Ints.fromBytes((byte) 0, (byte) 0, b1, b2);
        } else

        if ((ByteOrder.LITTLE_ENDIAN).equals(byteOrder)) {
            return Ints.fromBytes((byte) 0, (byte) 0, b2, b1);
        }

        throw new IOException("Unsupported byte order " + byteOrder);
    }

    static private int readInt(InputStream is, ByteOrder byteOrder) throws IOException {
        byte b1 = readByte(is);
        byte b2 = readByte(is);
        byte b3 = readByte(is);
        byte b4 = readByte(is);

        if ((ByteOrder.BIG_ENDIAN).equals(byteOrder)) {
            return Ints.fromBytes(b1, b2, b3, b4);
        } else

        if ((ByteOrder.LITTLE_ENDIAN).equals(byteOrder)) {
            return Ints.fromBytes(b4, b3, b2, b1);
        }

        throw new IOException("Unsupported byte order " + byteOrder);
    }

    static private long readLong(InputStream is, ByteOrder byteOrder) throws IOException {
        byte b1 = readByte(is);
        byte b2 = readByte(is);
        byte b3 = readByte(is);
        byte b4 = readByte(is);
        byte b5 = readByte(is);
        byte b6 = readByte(is);
        byte b7 = readByte(is);
        byte b8 = readByte(is);

        if ((ByteOrder.BIG_ENDIAN).equals(byteOrder)) {
            return Longs.fromBytes(b1, b2, b3, b4, b5, b6, b7, b8);
        } else

        if ((ByteOrder.LITTLE_ENDIAN).equals(byteOrder)) {
            return Longs.fromBytes(b8, b7, b6, b5, b4, b3, b2, b1);
        }

        throw new IOException("Unsupported byte order " + byteOrder);
    }

    static private float readFloat(InputStream is, ByteOrder byteOrder) throws IOException {
        return Float.intBitsToFloat(readInt(is, byteOrder));
    }

    static private double readDouble(InputStream is, ByteOrder byteOrder) throws IOException {
        return Double.longBitsToDouble(readLong(is, byteOrder));
    }

    /**
     * Unpickles one object from the stream.
     *
     * NOTE(review): the stream is unpickled without any restriction - only use with trusted input,
     * and confirm which classes are registered with the Unpickler.
     */
    static private Object readObject(InputStream is) throws IOException {
        Unpickler unpickler = new Unpickler();

        return unpickler.load(is);
    }

    /**
     * Reads a fixed-width byte string, decodes it as UTF-8 and strips trailing NUL padding.
     */
    static private String readString(InputStream is, int size) throws IOException {
        byte[] buffer = new byte[size];

        ByteStreams.readFully(is, buffer);

        return toString(buffer, "UTF-8");
    }

    /**
     * Reads a fixed-width UTF-32 string and strips trailing NUL padding.
     *
     * @param size The width in characters (4 bytes each).
     */
    static private String readUnicode(InputStream is, ByteOrder byteOrder, int size) throws IOException {
        byte[] buffer = new byte[size * 4];

        ByteStreams.readFully(is, buffer);

        if ((ByteOrder.BIG_ENDIAN).equals(byteOrder)) {
            return toString(buffer, "UTF-32BE");
        } else

        if ((ByteOrder.LITTLE_ENDIAN).equals(byteOrder)) {
            return toString(buffer, "UTF-32LE");
        }

        throw new IOException("Unsupported byte order " + byteOrder);
    }

    static private String toString(byte[] buffer, String encoding) throws IOException {
        String string = new String(buffer, encoding);

        // Trim trailing zero characters (scan once, then cut once)
        int end = string.length();
        while (end > 0 && string.charAt(end - 1) == '\0') {
            end--;
        }

        return string.substring(0, end);
    }

    /**
     * A parsed NumPy array-protocol type string, such as <code>"&lt;f8"</code> or <code>"|S10"</code>.
     *
     * The string consists of an optional byte order character, a kind character, and an optional size.
     *
     * http://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html
     * http://docs.scipy.org/doc/numpy/reference/generated/numpy.dtype.byteorder.html
     */
    static private class TypeDescriptor {

        private final String descr;

        /**
         * <code>null</code> when the byte order is not applicable ('|').
         */
        private final ByteOrder byteOrder;

        private final Kind kind;

        /**
         * The size in bytes (in characters for the Unicode kind); 0 when absent.
         */
        private final int size;

        private TypeDescriptor(String descr) {

            if (descr.isEmpty()) {
                throw new IllegalArgumentException("Empty type descriptor");
            }

            int i = 0;

            ByteOrder byteOrder;

            switch (descr.charAt(i)) {
            // Native
            case '=':
                byteOrder = ByteOrder.nativeOrder();
                i++;
                break;
            // Big-endian
            case '>':
                byteOrder = ByteOrder.BIG_ENDIAN;
                i++;
                break;
            // Little-endian
            case '<':
                byteOrder = ByteOrder.LITTLE_ENDIAN;
                i++;
                break;
            // Not applicable
            case '|':
                byteOrder = null;
                i++;
                break;
            // No byte order character - NumPy interprets the type in native byte order
            default:
                byteOrder = ByteOrder.nativeOrder();
                break;
            }

            if (i >= descr.length()) {
                throw new IllegalArgumentException("Type descriptor \"" + descr + "\" has no type kind");
            }

            Kind kind = Kind.forChar(descr.charAt(i));

            i++;

            int size = 0;

            if (i < descr.length()) {
                size = Integer.parseInt(descr.substring(i));
            }

            this.descr = descr;
            this.byteOrder = byteOrder;
            this.kind = kind;
            this.size = size;
        }

        /**
         * Reads one element of this type.
         *
         * @throws IOException If this kind/size combination is not supported, or the stream ends prematurely.
         */
        public Object read(InputStream is) throws IOException {
            ByteOrder byteOrder = getByteOrder();
            int size = getSize();

            switch (getKind()) {
            case BOOLEAN:
                if (size == 1) {
                    // NumPy treats any non-zero byte as true
                    return (readByte(is) != 0);
                }
                break;
            case INTEGER:
                switch (size) {
                case 1:
                    return (int) readByte(is);
                case 2:
                    // Reinterpret the 16 unsigned bits as a signed short
                    return (int) (short) readUnsignedShort(is, byteOrder);
                case 4:
                    return readInt(is, byteOrder);
                case 8:
                    return readLong(is, byteOrder);
                default:
                    break;
                }
                break;
            case UNSIGNED_INTEGER:
                switch (size) {
                case 1:
                    return readUnsignedByte(is);
                case 2:
                    return readUnsignedShort(is, byteOrder);
                case 4:
                    // Widen to long, because the value may exceed Integer.MAX_VALUE
                    return UnsignedInts.toLong(readInt(is, byteOrder));
                default:
                    break;
                }
                break;
            case FLOAT:
                switch (size) {
                case 4:
                    return readFloat(is, byteOrder);
                case 8:
                    return readDouble(is, byteOrder);
                default:
                    break;
                }
                break;
            case OBJECT:
                return readObject(is);
            case STRING:
                return readString(is, size);
            case UNICODE:
                return readUnicode(is, byteOrder, size);
            case VOID: {
                byte[] buffer = new byte[size];

                ByteStreams.readFully(is, buffer);

                return buffer;
            }
            default:
                break;
            }

            throw new IOException("Unsupported type descriptor \"" + this.descr + "\"");
        }

        public boolean isObject() {
            return getKind() == Kind.OBJECT;
        }

        public ByteOrder getByteOrder() {
            return this.byteOrder;
        }

        public Kind getKind() {
            return this.kind;
        }

        public int getSize() {
            return this.size;
        }

        static private enum Kind {
            BOOLEAN, INTEGER, UNSIGNED_INTEGER, FLOAT, COMPLEX_FLOAT, OBJECT, STRING, UNICODE, VOID;

            /**
             * Maps an array-protocol kind character to a kind.
             *
             * @throws IllegalArgumentException If the character is not recognized.
             */
            static public Kind forChar(char c) {

                switch (c) {
                case 'b':
                    return BOOLEAN;
                case 'i':
                    return INTEGER;
                case 'u':
                    return UNSIGNED_INTEGER;
                case 'f':
                    return FLOAT;
                case 'c':
                    return COMPLEX_FLOAT;
                case 'O':
                    return OBJECT;
                case 'S':
                case 'a':
                    return STRING;
                case 'U':
                    return UNICODE;
                case 'V':
                    return VOID;
                default:
                    throw new IllegalArgumentException("Unsupported type kind '" + c + "'");
                }
            }
        }
    }

    /**
     * The first six bytes of every NPY file: "\x93NUMPY".
     */
    private static final byte[] MAGIC_STRING = { (byte) '\u0093', 'N', 'U', 'M', 'P', 'Y' };
}