// Java tutorial
/*
 * Copyright (c) 2015 Villu Ruusmann
 *
 * This file is part of JPMML-SkLearn
 *
 * JPMML-SkLearn is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-SkLearn is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
 */
package numpy.core;

import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import com.google.common.io.ByteStreams;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import com.google.common.primitives.UnsignedInts;
import net.razorvine.pickle.Unpickler;
import net.razorvine.serpent.Parser;
import net.razorvine.serpent.ast.Ast;
import numpy.DType;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sklearn.TupleUtil;

/**
 * Static helpers for reading NumPy array payloads (NPY streams and raw, descr-typed data)
 * and for exposing the content of {@link NDArray} objects as Java lists.
 */
public class NDArrayUtil {

	private NDArrayUtil() {
	}

	/**
	 * Gets the shape of an array as primitive ints.
	 *
	 * @param array The array.
	 * @return One int per element of {@link NDArray#getShape()}.
	 */
	static public int[] getShape(NDArray array) {
		Object[] shape = array.getShape();

		// The shape elements are cast to numbers without an explicit check
		List<? extends Number> values = (List) Arrays.asList(shape);

		return Ints.toArray(ValueUtil.asIntegers(values));
	}

	/**
	 * Gets the payload of a one-dimensional array.
	 */
	static public List<?> getContent(NDArray array) {
		Object content = array.getContent();

		return asJavaList(array, (List<?>) content);
	}

	/**
	 * Gets the payload of the specified dimension of a multi-dimensional array.
	 *
	 * @param key The dimension.
	 */
	static public List<?> getContent(NDArray array, String key) {
		Map<String, ?> content = (Map<String, ?>) array.getContent();

		return asJavaList(array, (List<?>) content.get(key));
	}

	/**
	 * Returns the values in row-major order, reordering them if the array is flagged as Fortran-ordered.
	 *
	 * @throws IllegalArgumentException If a Fortran-ordered array has neither one nor two dimensions.
	 */
	static private <E> List<E> asJavaList(NDArray array, List<E> values) {
		boolean fortranOrder = array.getFortranOrder();

		if (fortranOrder) {
			int[] shape = getShape(array);

			switch (shape.length) {
				case 1:
					// The element order of a one-dimensional array does not depend on the memory layout
					return values;
				case 2:
					return toJavaList(values, shape[0], shape[1]);
				default:
					throw new IllegalArgumentException("Fortran-ordered arrays with " + shape.length + " dimensions are not supported");
			}
		}

		return values;
	}

	/**
	 * Translates a column-major (ie. Fortran-type) array to a row-major (ie. C-type) array.
	 */
	static private <E> List<E> toJavaList(List<E> values, int rows, int columns) {
		List<E> result = new ArrayList<>(values.size());

		for (int i = 0; i < values.size(); i++) {
			// Position of the i-th row-major element
			int row = i / columns;
			int column = i % columns;

			// The same element in the column-major source
			E value = values.get((column * rows) + row);

			result.add(value);
		}

		return result;
	}

	/**
	 * Parses an NPY stream (format version 1.0 only) into an {@link NDArray}.
	 *
	 * http://docs.scipy.org/doc/numpy-dev/neps/npy-format.html
	 *
	 * @param is The stream, positioned at the magic string. It is read to the end, but not closed.
	 * @throws IOException If the magic string or the format version do not match, or if the stream ends prematurely.
	 */
	static public NDArray parseNpy(InputStream is) throws IOException {
		byte[] magicBytes = new byte[MAGIC_STRING.length];

		ByteStreams.readFully(is, magicBytes);

		if (!Arrays.equals(magicBytes, MAGIC_STRING)) {
			throw new IOException("Invalid NPY magic string");
		}

		int majorVersion = readUnsignedByte(is);
		int minorVersion = readUnsignedByte(is);

		if (majorVersion != 1 || minorVersion != 0) {
			throw new IOException("Unsupported NPY format version " + majorVersion + "." + minorVersion);
		}

		// An unsigned 16-bit value; always in range 0..65535
		int headerLength = readUnsignedShort(is, ByteOrder.LITTLE_ENDIAN);

		byte[] headerBytes = new byte[headerLength];

		ByteStreams.readFully(is, headerBytes);

		// Decode with an explicit single-byte charset instead of the platform default one
		String header = new String(headerBytes, StandardCharsets.ISO_8859_1);

		// Remove trailing whitespace
		header = header.trim();

		Map<String, ?> headerDict = parseDict(header);

		Object descr = headerDict.get("descr");
		Boolean fortranOrder = (Boolean) headerDict.get("fortran_order");
		Object[] shape = (Object[]) headerDict.get("shape");

		// Everything after the header is the array data
		byte[] data = ByteStreams.toByteArray(is);

		NDArray result = new NDArray();
		result.__setstate__(
			new Object[] {
				Arrays.asList(majorVersion, minorVersion),
				shape,
				descr,
				fortranOrder,
				data
			});

		return result;
	}

	/**
	 * Parses array data according to a type description.
	 *
	 * @param is The stream to read from.
	 * @param descr A {@link DType}, a type descriptor string, or a list of (name, type descriptor string) tuples.
	 * @param shape The shape of the array. The number of elements to read is the product of its elements.
	 * @return A list of values for a string descr; otherwise an insertion-ordered map of names to lists of values.
	 */
	static public Object parseData(InputStream is, Object descr, Object[] shape) throws IOException {

		if (descr instanceof DType) {
			DType dType = (DType) descr;

			descr = dType.toDescr();
		}

		int length = 1;

		for (int i = 0; i < shape.length; i++) {
			length *= ValueUtil.asInt((Number) shape[i]);
		}

		if (descr instanceof String) {
			return parseArray(is, (String) descr, length);
		}

		List<Object[]> dims = (List<Object[]>) descr;

		Map<String, List<?>> result = new LinkedHashMap<>();

		// Element 1 of every tuple is the type descriptor
		List<Object[]> objects = parseMultiArray(is, (List<String>) TupleUtil.extractElement(dims, 1), length);

		for (int i = 0; i < dims.size(); i++) {
			Object[] dim = dims.get(i);

			// Element 0 of every tuple is the name
			result.put((String) dim[0], TupleUtil.extractElement(objects, i));
		}

		return result;
	}

	/**
	 * Reads values of a single type until at least {@code length} values have been collected.
	 *
	 * For an object type, every value that is read is cast to {@link NDArray}, and its whole content is appended to the result.
	 *
	 * @param descr The type descriptor string.
	 * @param length The requested number of values.
	 */
	static public List<Object> parseArray(InputStream is, String descr, int length) throws IOException {
		List<Object> result = new ArrayList<>(length);

		TypeDescriptor descriptor = new TypeDescriptor(descr);

		while (result.size() < length) {
			Object element = descriptor.read(is);

			if (descriptor.isObject()) {
				NDArray array = (NDArray) element;

				result.addAll(NDArrayUtil.getContent(array));

				continue;
			}

			result.add(element);
		}

		return result;
	}

	/**
	 * Reads {@code length} records, where every record holds one value per type descriptor.
	 *
	 * @param descrs The type descriptor strings, in record order.
	 * @param length The number of records.
	 * @throws IllegalArgumentException If any of the type descriptors denotes an object type.
	 */
	static public List<Object[]> parseMultiArray(InputStream is, List<String> descrs, int length) throws IOException {
		List<Object[]> result = new ArrayList<>(length);

		List<TypeDescriptor> descriptors = new ArrayList<>();

		for (String descr : descrs) {
			TypeDescriptor descriptor = new TypeDescriptor(descr);

			if (descriptor.isObject()) {
				throw new IllegalArgumentException(descr);
			}

			descriptors.add(descriptor);
		}

		for (int i = 0; i < length; i++) {
			Object[] element = new Object[descriptors.size()];

			for (int j = 0; j < descriptors.size(); j++) {
				TypeDescriptor descriptor = descriptors.get(j);

				element[j] = descriptor.read(is);
			}

			result.add(element);
		}

		return result;
	}

	/**
	 * Parses the textual header dictionary using the Serpent parser.
	 */
	static private Map<String, ?> parseDict(String string) {
		Parser parser = new Parser();

		Ast ast = parser.parse(string);

		return (Map<String, ?>) ast.getData();
	}

	/**
	 * Reads a single signed byte.
	 *
	 * @throws EOFException If the stream has ended.
	 */
	static private byte readByte(InputStream is) throws IOException {
		int b = is.read();

		if (b < 0) {
			throw new EOFException();
		}

		return (byte) b;
	}

	/**
	 * Reads a single byte as a value in range 0..255.
	 *
	 * @throws EOFException If the stream has ended.
	 */
	static private int readUnsignedByte(InputStream is) throws IOException {
		int b = is.read();

		if (b < 0) {
			throw new EOFException();
		}

		return b;
	}

	/**
	 * Reads two bytes as a value in range 0..65535.
	 *
	 * @throws IOException If the byte order is neither big-endian nor little-endian (eg. {@code null}).
	 */
	static private int readUnsignedShort(InputStream is, ByteOrder byteOrder) throws IOException {
		byte b1 = readByte(is);
		byte b2 = readByte(is);

		if ((ByteOrder.BIG_ENDIAN).equals(byteOrder)) {
			return Ints.fromBytes((byte) 0, (byte) 0, b1, b2);
		} else if ((ByteOrder.LITTLE_ENDIAN).equals(byteOrder)) {
			return Ints.fromBytes((byte) 0, (byte) 0, b2, b1);
		}

		throw new IOException("Unsupported byte order: " + byteOrder);
	}

	/**
	 * Reads four bytes as a signed int.
	 *
	 * @throws IOException If the byte order is neither big-endian nor little-endian (eg. {@code null}).
	 */
	static private int readInt(InputStream is, ByteOrder byteOrder) throws IOException {
		byte b1 = readByte(is);
		byte b2 = readByte(is);
		byte b3 = readByte(is);
		byte b4 = readByte(is);

		if ((ByteOrder.BIG_ENDIAN).equals(byteOrder)) {
			return Ints.fromBytes(b1, b2, b3, b4);
		} else if ((ByteOrder.LITTLE_ENDIAN).equals(byteOrder)) {
			return Ints.fromBytes(b4, b3, b2, b1);
		}

		throw new IOException("Unsupported byte order: " + byteOrder);
	}

	/**
	 * Reads eight bytes as a signed long.
	 *
	 * @throws IOException If the byte order is neither big-endian nor little-endian (eg. {@code null}).
	 */
	static private long readLong(InputStream is, ByteOrder byteOrder) throws IOException {
		byte b1 = readByte(is);
		byte b2 = readByte(is);
		byte b3 = readByte(is);
		byte b4 = readByte(is);
		byte b5 = readByte(is);
		byte b6 = readByte(is);
		byte b7 = readByte(is);
		byte b8 = readByte(is);

		if ((ByteOrder.BIG_ENDIAN).equals(byteOrder)) {
			return Longs.fromBytes(b1, b2, b3, b4, b5, b6, b7, b8);
		} else if ((ByteOrder.LITTLE_ENDIAN).equals(byteOrder)) {
			return Longs.fromBytes(b8, b7, b6, b5, b4, b3, b2, b1);
		}

		throw new IOException("Unsupported byte order: " + byteOrder);
	}

	/**
	 * Reads four bytes and reinterprets the bits as a float.
	 */
	static private float readFloat(InputStream is, ByteOrder byteOrder) throws IOException {
		return Float.intBitsToFloat(readInt(is, byteOrder));
	}

	/**
	 * Reads eight bytes and reinterprets the bits as a double.
	 */
	static private double readDouble(InputStream is, ByteOrder byteOrder) throws IOException {
		return Double.longBitsToDouble(readLong(is, byteOrder));
	}

	/**
	 * Reads a single pickled object.
	 */
	static private Object readObject(InputStream is) throws IOException {
		// NOTE(review): this unpickles whatever the stream contains - confirm that callers only pass trusted data
		Unpickler unpickler = new Unpickler();

		return unpickler.load(is);
	}

	/**
	 * Reads a fixed-size byte string, and decodes it as UTF-8.
	 *
	 * @param size The size in bytes.
	 */
	static private String readString(InputStream is, int size) throws IOException {
		byte[] buffer = new byte[size];

		ByteStreams.readFully(is, buffer);

		return toString(buffer, "UTF-8");
	}

	/**
	 * Reads a fixed-size unicode string, and decodes it as UTF-32.
	 *
	 * @param size The size in four-byte code units.
	 * @throws IOException If the byte order is neither big-endian nor little-endian (eg. {@code null}).
	 */
	static private String readUnicode(InputStream is, ByteOrder byteOrder, int size) throws IOException {
		byte[] buffer = new byte[size * 4];

		ByteStreams.readFully(is, buffer);

		if ((ByteOrder.BIG_ENDIAN).equals(byteOrder)) {
			return toString(buffer, "UTF-32BE");
		} else if ((ByteOrder.LITTLE_ENDIAN).equals(byteOrder)) {
			return toString(buffer, "UTF-32LE");
		}

		throw new IOException("Unsupported byte order: " + byteOrder);
	}

	/**
	 * Decodes a buffer, and strips trailing zero characters.
	 *
	 * @param encoding The name of the charset.
	 */
	static private String toString(byte[] buffer, String encoding) throws IOException {
		String string = new String(buffer, encoding);

		// Locate the end of the meaningful content first, so that the string is cut only once
		int end = string.length();

		while (end > 0 && string.charAt(end - 1) == '\0') {
			end--;
		}

		return string.substring(0, end);
	}

	/**
	 * A parsed type descriptor string: an optional byte order character, a kind character and an optional size.
	 *
	 * http://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html
	 * http://docs.scipy.org/doc/numpy/reference/generated/numpy.dtype.byteorder.html
	 */
	static private class TypeDescriptor {

		private ByteOrder byteOrder = null;

		private Kind kind = null;

		private int size = 0;

		/**
		 * @param descr The type descriptor string (eg. {@code <f8}).
		 * @throws IllegalArgumentException If the kind character is not recognized, or if the size is not an integer.
		 */
		private TypeDescriptor(String descr) {
			int i = 0;

			// Stays null if there is no byte order character, or if it is "not applicable"
			ByteOrder byteOrder = null;

			switch (descr.charAt(i)) {
				// Native
				case '=':
					byteOrder = ByteOrder.nativeOrder();
					i++;
					break;
				// Big-endian
				case '>':
					byteOrder = ByteOrder.BIG_ENDIAN;
					i++;
					break;
				// Little-endian
				case '<':
					byteOrder = ByteOrder.LITTLE_ENDIAN;
					i++;
					break;
				// Not applicable
				case '|':
					i++;
					break;
			}

			setByteOrder(byteOrder);

			Kind kind = Kind.forChar(descr.charAt(i));
			i++;

			setKind(kind);

			// The size is optional; it stays at zero if missing
			if (i < descr.length()) {
				int size = Integer.parseInt(descr.substring(i));

				setSize(size);
			}
		}

		/**
		 * Reads a single value of this type.
		 *
		 * @return A Boolean, Integer, Long, Float, Double, String or byte array, or an unpickled object, depending on the kind and size.
		 * @throws IOException If the combination of kind and size is not supported, or if reading fails.
		 */
		public Object read(InputStream is) throws IOException {
			Kind kind = getKind();
			ByteOrder byteOrder = getByteOrder();
			int size = getSize();

			switch (kind) {
				case BOOLEAN:
					{
						switch (size) {
							case 1:
								return (readByte(is) == 1);
							default:
								break;
						}
					}
					break;
				case INTEGER:
					{
						switch (size) {
							case 4:
								return readInt(is, byteOrder);
							case 8:
								return readLong(is, byteOrder);
							default:
								break;
						}
					}
					break;
				case UNSIGNED_INTEGER:
					{
						switch (size) {
							case 4:
								// Widened to long in order to keep the value non-negative
								return UnsignedInts.toLong(readInt(is, byteOrder));
							default:
								break;
						}
					}
					break;
				case FLOAT:
					{
						switch (size) {
							case 4:
								return readFloat(is, byteOrder);
							case 8:
								return readDouble(is, byteOrder);
							default:
								break;
						}
					}
					break;
				case OBJECT:
					{
						return readObject(is);
					}
				case STRING:
					{
						return readString(is, size);
					}
				case UNICODE:
					{
						return readUnicode(is, byteOrder, size);
					}
				case VOID:
					{
						byte[] buffer = new byte[size];

						ByteStreams.readFully(is, buffer);

						return buffer;
					}
				default:
					break;
			}

			throw new IOException("Unsupported type: kind=" + kind + ", size=" + size);
		}

		/**
		 * @return {@code true} if the kind is {@link Kind#OBJECT}, {@code false} otherwise.
		 */
		public boolean isObject() {
			Kind kind = getKind();

			switch (kind) {
				case OBJECT:
					return true;
				default:
					return false;
			}
		}

		public ByteOrder getByteOrder() {
			return this.byteOrder;
		}

		private void setByteOrder(ByteOrder byteOrder) {
			this.byteOrder = byteOrder;
		}

		public Kind getKind() {
			return this.kind;
		}

		private void setKind(Kind kind) {
			this.kind = kind;
		}

		public int getSize() {
			return this.size;
		}

		private void setSize(int size) {
			this.size = size;
		}

		/**
		 * The kind of a type, as identified by the kind character of a type descriptor string.
		 */
		static private enum Kind {
			BOOLEAN,
			INTEGER,
			UNSIGNED_INTEGER,
			FLOAT,
			COMPLEX_FLOAT,
			OBJECT,
			STRING,
			UNICODE,
			VOID;

			/**
			 * @throws IllegalArgumentException If the character does not identify a kind.
			 */
			static public Kind forChar(char c) {

				switch (c) {
					case 'b':
						return BOOLEAN;
					case 'i':
						return INTEGER;
					case 'u':
						return UNSIGNED_INTEGER;
					case 'f':
						return FLOAT;
					case 'c':
						return COMPLEX_FLOAT;
					case 'O':
						return OBJECT;
					case 'S':
					case 'a':
						return STRING;
					case 'U':
						return UNICODE;
					case 'V':
						return VOID;
					default:
						throw new IllegalArgumentException("Unsupported kind character: " + c);
				}
			}
		}
	}

	// The first six bytes of every NPY stream
	private static final byte[] MAGIC_STRING = { (byte) '\u0093', 'N', 'U', 'M', 'P', 'Y' };
}