edu.iu.daal_kmeans.regroupallgather.CenCalcTask.java Source code

Java tutorial

Introduction

Here is the source code for edu.iu.daal_kmeans.regroupallgather.CenCalcTask.java

Source

/*
 * Copyright 2013-2016 Indiana University
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package edu.iu.daal_kmeans.regroupallgather;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import edu.iu.harp.schdynamic.Task;
import edu.iu.harp.partition.Partition;
import edu.iu.harp.partition.Table;
import edu.iu.harp.resource.DoubleArray;

/**
 * Task that assigns a block of points to their nearest centroids and accumulates, per centroid,
 * the number of assigned points and the sum of their coordinates into a task-local buffer.
 *
 * <p>Layout, as implied by the indexing in {@link #run}: every vector occupies {@code cenVecSize}
 * consecutive doubles.
 * <ul>
 *   <li>In a point block, the first slot of each vector is skipped and the remaining
 *       {@code cenVecSize - 1} slots are the coordinates.
 *   <li>In a centroid array, the first slot of each vector is not used for the distance and the
 *       remaining slots are the coordinates.
 *   <li>In an accumulator ({@link #getLocal()}), the first slot of each vector counts the points
 *       assigned to that centroid and the remaining slots hold the running coordinate sums.
 * </ul>
 *
 * <p>Partition ids of the centroid table are used directly as array indices, so they are assumed
 * to lie in {@code [0, numPartitions)}. NOTE(review): confirm against the table's producer.
 */
public class CenCalcTask implements Task<double[], Object> {

    protected static final Log LOG = LogFactory.getLog(CenCalcTask.class);

    /** Centroid arrays indexed by partition id; holds the references from DoubleArray.get(), no copy. */
    private final double[][] centroids;
    /** Accumulators indexed by partition id, each of length {@code array.size()} at construction. */
    private final double[][] local;
    /** Number of doubles per vector: one leading slot followed by the coordinates. */
    private final int cenVecSize;

    /**
     * Creates a task bound to the centroids currently held in {@code cenTable}.
     *
     * @param cenTable centroid table; one zero-initialized accumulator of {@code array.size()}
     *     doubles is allocated per partition
     * @param cenVecSize number of doubles per vector (leading slot plus coordinates)
     * @throws IllegalArgumentException if {@code cenVecSize < 1}
     */
    public CenCalcTask(Table<DoubleArray> cenTable, int cenVecSize) {
        if (cenVecSize < 1) {
            throw new IllegalArgumentException("cenVecSize must be at least 1, got " + cenVecSize);
        }
        this.cenVecSize = cenVecSize;
        centroids = new double[cenTable.getNumPartitions()][];
        local = new double[centroids.length][];
        for (Partition<DoubleArray> partition : cenTable.getPartitions()) {
            int partitionID = partition.id();
            DoubleArray array = partition.get();
            centroids[partitionID] = array.get();
            local[partitionID] = new double[array.size()];
        }
    }

    /**
     * Replaces the centroid array of every partition present in {@code cenTable}.
     *
     * <p>Partitions absent from the table keep their previous array. The accumulators are neither
     * cleared nor resized here.
     *
     * @param cenTable table holding the new centroid arrays
     */
    public void update(Table<DoubleArray> cenTable) {
        for (Partition<DoubleArray> partition : cenTable.getPartitions()) {
            int partitionID = partition.id();
            DoubleArray array = partition.get();
            centroids[partitionID] = array.get();
        }
    }

    /**
     * Returns the live accumulator arrays (not a copy), indexed by partition id.
     *
     * @return per-partition accumulators: a count slot followed by coordinate sums, per vector
     */
    public double[][] getLocal() {
        return local;
    }

    /**
     * Assigns every vector in {@code points} to its nearest centroid and adds it to that
     * centroid's accumulator.
     *
     * <p>Distance is squared Euclidean over the coordinate slots. Candidates are scanned in
     * partition order, then offset order, and the first strict minimum wins. If no distance is
     * below {@code Double.MAX_VALUE} (for example all are NaN or infinite), the point falls back to
     * the first vector of partition 0.
     *
     * @param points block of vectors, {@code cenVecSize} doubles each, with the first slot of each
     *     vector ignored
     * @return always {@code null}; results are accumulated into {@link #getLocal()}
     * @throws IllegalArgumentException if {@code points.length} is not a multiple of
     *     {@code cenVecSize}; checked before any accumulator is modified
     * @throws IllegalStateException if a point has no centroid slot it can be assigned to
     */
    @Override
    public Object run(double[] points) throws Exception {
        if (points.length % cenVecSize != 0) {
            throw new IllegalArgumentException("points.length (" + points.length
                + ") is not a multiple of cenVecSize (" + cenVecSize + ")");
        }
        for (int start = 0; start < points.length; start += cenVecSize) {
            // The leading slot of each point vector is skipped; coordinates follow it.
            int coordStart = start + 1;
            double minDistance = Double.MAX_VALUE;
            int minCenParID = 0;
            int minOffset = 0;
            for (int j = 0; j < centroids.length; j++) {
                // Bound by the accumulator length (array.size() at construction), not by
                // centroids[j].length.
                for (int k = 0; k < local[j].length; k += cenVecSize) {
                    double distance = squaredDistance(points, coordStart, centroids[j], k + 1);
                    if (distance < minDistance) {
                        minDistance = distance;
                        minCenParID = j;
                        minOffset = k;
                    }
                }
            }
            if (minCenParID >= local.length || minOffset >= local[minCenParID].length) {
                throw new IllegalStateException("No centroid available for the point starting at index "
                    + start + " (empty centroid table, or no finite distance and partition 0 is empty)");
            }
            accumulate(points, coordStart, minCenParID, minOffset);
        }
        return null;
    }

    /** Squared Euclidean distance over {@code cenVecSize - 1} coordinates, summed in index order. */
    private double squaredDistance(double[] points, int pointCoordStart,
                                   double[] centroid, int centroidCoordStart) {
        double distance = 0.0;
        for (int d = 0; d < cenVecSize - 1; d++) {
            double diff = points[pointCoordStart + d] - centroid[centroidCoordStart + d];
            distance += diff * diff;
        }
        return distance;
    }

    /** Increments the count slot at {@code offset} and adds the point's coordinates after it. */
    private void accumulate(double[] points, int coordStart, int partitionID, int offset) {
        double[] acc = local[partitionID];
        acc[offset]++;
        for (int d = 1; d < cenVecSize; d++) {
            acc[offset + d] += points[coordStart + d - 1];
        }
    }
}