mlbench.kmeans.KmeansUtils.java Source code

Java tutorial

Introduction

Here is the source code for mlbench.kmeans.KmeansUtils.java

Source

/**
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package mlbench.kmeans;

import mpi.MPI;
import mpi.MPIException;
import mpid.core.MPI_D;
import mpid.core.MPI_D_Exception;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.InputBuffer;
import org.apache.hadoop.io.OutputBuffer;
import org.apache.hadoop.io.serializer.Deserializer;
import org.apache.hadoop.io.serializer.SerializationFactory;
import org.apache.hadoop.io.serializer.Serializer;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.mahout.math.Vector;

import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.util.List;

public class KmeansUtils {

    /**
     * Opens {@code path} on the default {@link FileSystem} of {@code conf} for reading.
     *
     * @param path file to open
     * @param conf configuration used to obtain the file system via {@link FileSystem#get}
     * @return a stream positioned at the start of the file
     * @throws IllegalStateException if the file system cannot be obtained or the file cannot be
     *         opened; the underlying {@link IOException} is attached as the cause
     */
    static DataInputStream readFromHDFSF(Path path, JobConf conf) {
        try {
            FileSystem fs = FileSystem.get(conf);
            return new DataInputStream(fs.open(path));
        } catch (IOException e) {
            // Fail fast: a DataInputStream wrapping null would only fail later, on the first
            // read, with a NullPointerException that hides the real I/O error.
            throw new IllegalStateException("Cannot open " + path + " for reading", e);
        }
    }

    /**
     * Creates (or overwrites) {@code filePath} on the default {@link FileSystem} of {@code conf}.
     *
     * @param filePath path of the file to create
     * @param conf     configuration used to obtain the file system
     * @return the output stream, or {@code null} if the file system could not be obtained or the
     *         file could not be created (the error is printed to stderr)
     */
    static OutputStream getOutputStream(String filePath, Configuration conf) {
        Path out = new Path(filePath);
        try {
            // create(path, true) overwrites an existing file, so no separate existence check
            // is required; obtain the file system once instead of per call.
            FileSystem fs = FileSystem.get(conf);
            return fs.create(out, true);
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Returns the default {@link FileSystem} for {@code conf}.
     *
     * @param conf Hadoop configuration
     * @return the file system, or {@code null} if it could not be obtained (the error is printed
     *         to stderr)
     */
    static FileSystem getFileSystem(Configuration conf) {
        FileSystem fs = null;
        try {
            fs = FileSystem.get(conf);
        } catch (IOException e) {
            e.printStackTrace();
        }
        return fs;
    }

    /**
     * Adds {@code vals} element-wise into {@code sum} in place.
     *
     * @param sum  accumulator, modified in place
     * @param vals values to add; must have the same length as {@code sum}
     * @throws MPI_D_Exception if the array lengths differ
     */
    static void accumulate(double[] sum, double[] vals) throws MPI_D_Exception {
        if (sum.length != vals.length) {
            throw new MPI_D_Exception("Array length mismatch: sum has " + sum.length
                    + " elements, vals has " + vals.length);
        }
        for (int i = 0; i < sum.length; i++) {
            sum[i] += vals[i];
        }
    }

    /**
     * Adds the components of {@code vector} element-wise into {@code sum} in place.
     *
     * @param sum    accumulator, modified in place
     * @param vector values to add; {@code vector.size()} must equal {@code sum.length}
     * @throws MPI_D_Exception if the sizes differ
     */
    static void accumulate(double[] sum, Vector vector) throws MPI_D_Exception {
        if (sum.length != vector.size()) {
            throw new MPI_D_Exception("Array length mismatch: sum has " + sum.length
                    + " elements, vector has " + vector.size());
        }
        for (int i = 0; i < sum.length; i++) {
            sum[i] += vector.get(i);
        }
    }

    /**
     * Computes the L1 (Manhattan) distance {@code sum(|d1[i] - d2[i]|)} over the first
     * {@code min(d1.length, d2.length)} components. Trailing components of the longer array are
     * ignored.
     *
     * @param d1 first point
     * @param d2 second point
     * @return the L1 distance (NaN if any compared component is NaN)
     * @throws MPI_D_Exception not thrown: a sum of absolute values cannot become negative. The
     *         clause is kept for source compatibility with callers that catch it.
     */
    static double distance(double[] d1, double[] d2) throws MPI_D_Exception {
        int len = Math.min(d1.length, d2.length);
        double sum = 0;
        for (int i = 0; i < len; i++) {
            sum += Math.abs(d1[i] - d2[i]);
        }
        return sum;
    }

    /**
     * Moves lists of {@link PointVector} centers between MPI ranks using fixed-size byte slots.
     *
     * <p>Each slot is {@code PART_BUFFER_LENGTH} bytes. It holds an int center count, then an
     * int payload length (both via {@link ByteBuffer}'s default byte order), then the serialized
     * centers.
     */
    static class CenterTransfer {
        /** Size in bytes of one int in the slot header. */
        private static final int INTSIZE = Integer.SIZE >> 3;
        /** Fixed size in bytes of one slot (header plus payload). */
        private static final int PART_BUFFER_LENGTH = 1 << 15;
        /** Header bytes: center count followed by payload length. */
        private static final int HEADER_LENGTH = INTSIZE << 1;

        private final Serializer<PointVector> serialize;
        private final Deserializer<PointVector> deserializer;
        private final JobConf config;
        private final int rank;
        private final int size;

        /**
         * @param conf configuration used to look up the {@link PointVector} serialization
         * @param rank this process's rank; rank 0 acts as the root
         * @param size number of participating ranks
         */
        public CenterTransfer(JobConf conf, int rank, int size) {
            this.config = conf;
            this.rank = rank;
            this.size = size;
            SerializationFactory factory = new SerializationFactory(config);
            this.serialize = factory.getSerializer(PointVector.class);
            this.deserializer = factory.getDeserializer(PointVector.class);
        }

        /**
         * Decodes {@code groupSize} consecutive slots of {@code data} and appends their centers
         * to {@code centers}. The list is NOT cleared first.
         *
         * @param data      buffer of at least {@code groupSize * PART_BUFFER_LENGTH} bytes
         * @param groupSize number of slots to decode
         * @param centers   list receiving the decoded centers
         * @throws IOException if deserialization fails
         */
        void deserialize(byte[] data, int groupSize, List<PointVector> centers) throws IOException {
            for (int k = 0; k < groupSize; k++) {
                readSlot(data, k * PART_BUFFER_LENGTH, centers);
            }
        }

        /**
         * Replaces the contents of {@code centers} with the centers decoded from the first slot
         * of {@code data}.
         *
         * @param data    buffer of at least {@code PART_BUFFER_LENGTH} bytes
         * @param centers list that is cleared and then refilled
         * @throws IOException if deserialization fails
         */
        void deserialize(byte[] data, List<PointVector> centers) throws IOException {
            centers.clear();
            readSlot(data, 0, centers);
        }

        /**
         * Same as {@link #deserialize(byte[], int, List)}. It is kept for existing callers.
         */
        void deserialize2(byte[] data, int groupSize, List<PointVector> centers) throws IOException {
            deserialize(data, groupSize, centers);
        }

        /**
         * Serializes {@code centers} into a new single-slot buffer.
         *
         * @param centers centers to encode
         * @return a {@code PART_BUFFER_LENGTH}-byte array in slot layout
         * @throws IOException              if serialization fails
         * @throws IllegalArgumentException if the serialized centers do not fit in one slot
         */
        byte[] serializer(List<PointVector> centers) throws IOException {
            OutputBuffer out = new OutputBuffer();
            serialize.open(out);
            try {
                for (PointVector p : centers) {
                    serialize.serialize(p);
                }
            } finally {
                serialize.close();
            }
            int len = out.getLength();
            // Unchecked on purpose: an IOException here would be swallowed by
            // gatherCentersByP2P, skipping the Send and leaving rank 0 blocked in Recv.
            if (len > PART_BUFFER_LENGTH - HEADER_LENGTH) {
                throw new IllegalArgumentException("Serialized centers (" + len
                        + " bytes) exceed slot capacity of " + (PART_BUFFER_LENGTH - HEADER_LENGTH)
                        + " bytes");
            }

            byte[] ds = new byte[PART_BUFFER_LENGTH];
            IntBuffer header = ByteBuffer.wrap(ds).asIntBuffer();
            header.put(centers.size());
            header.put(len);
            System.arraycopy(out.getData(), 0, ds, HEADER_LENGTH, len);
            return ds;
        }

        /**
         * Gathers centers on rank 0 with point-to-point messages.
         *
         * <p>Every non-root rank sends its serialized centers to rank 0. Rank 0 appends the
         * centers received from ranks 1..size-1 to its own {@code centers} list. Its own slot 0
         * stays zero-filled and therefore decodes to nothing. I/O and MPI errors are printed and
         * otherwise ignored.
         *
         * @param centers on non-root ranks, the centers to send; on rank 0, the list to extend
         */
        void gatherCentersByP2P(List<PointVector> centers) {
            try {
                if (rank != 0) {
                    byte[] outBuffer = serializer(centers);
                    MPI_D.COMM_BIPARTITE_A.Send(outBuffer, 0, PART_BUFFER_LENGTH, MPI.BYTE, 0, 0);
                } else {
                    byte[] inBuffer = new byte[size * PART_BUFFER_LENGTH];
                    for (int i = 1; i < size; i++) {
                        MPI_D.COMM_BIPARTITE_A.Recv(inBuffer, i * PART_BUFFER_LENGTH, PART_BUFFER_LENGTH,
                                MPI.BYTE, i, 0);
                    }
                    deserialize(inBuffer, size, centers);
                }
            } catch (IOException | MPIException e) {
                e.printStackTrace();
            }
        }

        /**
         * Broadcasts rank 0's centers to every rank on {@code COMM_BIPARTITE_O}. Afterwards each
         * rank's {@code centers} list is replaced with the decoded broadcast contents.
         *
         * @param centers on rank 0, the centers to broadcast; on every rank, the list to refill
         * @throws IOException  if (de)serialization fails
         * @throws MPIException if the broadcast fails
         */
        void broadcastCenters(List<PointVector> centers) throws IOException, MPIException {
            // Only the root's buffer content matters; other ranks just supply a receive buffer.
            byte[] ds = rank == 0 ? serializer(centers) : new byte[PART_BUFFER_LENGTH];
            MPI_D.COMM_BIPARTITE_O.Bcast(ds, 0, PART_BUFFER_LENGTH, MPI.BYTE, 0);
            deserialize(ds, centers);
        }

        /**
         * Decodes one slot starting at {@code offset} and appends its centers to {@code centers}.
         */
        private void readSlot(byte[] data, int offset, List<PointVector> centers) throws IOException {
            IntBuffer header = ByteBuffer.wrap(data, offset, PART_BUFFER_LENGTH).asIntBuffer();
            int count = header.get();
            int len = header.get();
            InputBuffer in = new InputBuffer();
            in.reset(data, offset + HEADER_LENGTH, len);
            deserializer.open(in);
            try {
                for (int i = 0; i < count; i++) {
                    PointVector point = (PointVector) ReflectionUtils.newInstance(PointVector.class, config);
                    centers.add(deserializer.deserialize(point));
                }
            } finally {
                deserializer.close();
            }
        }
    }

    /**
     * A no-op {@link Reporter}. Status and counter updates are ignored, counter and split
     * lookups return {@code null}, and progress is always 0.
     */
    static class EmptyReport implements Reporter {

        @Override
        public void setStatus(String status) {

        }

        @Override
        public Counters.Counter getCounter(Enum<?> name) {
            return null;
        }

        @Override
        public Counters.Counter getCounter(String group, String name) {
            return null;
        }

        @Override
        public void incrCounter(Enum<?> key, long amount) {

        }

        @Override
        public void incrCounter(String group, String counter, long amount) {

        }

        @Override
        public InputSplit getInputSplit() throws UnsupportedOperationException {
            return null;
        }

        @Override
        public float getProgress() {
            return 0;
        }

        @Override
        public void progress() {

        }
    }
}