com.cloudera.oryx.ml.speed.als.MockModelUpdateGenerator.java Source code

Java tutorial

Introduction

Here is the source code for com.cloudera.oryx.ml.speed.als.MockModelUpdateGenerator.java

Source

/*
 * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */

package com.cloudera.oryx.ml.speed.als;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.math3.random.RandomGenerator;
import org.dmg.pmml.PMML;

import com.cloudera.oryx.common.collection.Pair;
import com.cloudera.oryx.common.math.VectorMath;
import com.cloudera.oryx.common.pmml.PMMLUtils;
import com.cloudera.oryx.kafka.util.RandomDatumGenerator;

/**
 * Generates a fixed, repeating sequence of mock ALS model-update messages from a small
 * hard-coded data set. Every tenth ID yields a "MODEL" message carrying a skeleton PMML
 * document; every other ID yields an "UP" message carrying one row of {@link #X} or
 * {@link #Y} together with the corresponding known IDs from {@link #A} or {@link #At}.
 */
public final class MockModelUpdateGenerator implements RandomDatumGenerator<String, String> {

    private static final ObjectMapper MAPPER = new ObjectMapper();

    /*
     A = [ 1 0 0 1 0 ; 0 1 0 1 1 ; 1 1 1 1 0 ; 0 0 1 0 0 ]

     Rows correspond to X IDs 6..9 and columns to Y IDs 1..5.
     */

    /**
     * Non-zero entries of {@code A} by row: X ID to the Y IDs it is associated with.
     * Unmodifiable, like {@link #X} and {@link #Y}.
     */
    public static final Map<String, Collection<String>> A = buildA();

    /**
     * Transpose of {@link #A}: Y ID to the X IDs associated with it. Unmodifiable.
     */
    public static final Map<String, Collection<String>> At = buildAt();

    /*
     [U,S,V] = svd(A)
        
     U = U(:,1:2)
     S = S(1:2,1:2)
     V = V(:,1:2)
        
     X = U*sqrt(S)
     Y = V*sqrt(S)
     */

    /** Rank-2 factor rows for X IDs 6..9, keyed by ID string. */
    public static final Map<String, float[]> X = buildMatrix(6,
            new double[][] { { -0.679001918401210, 0.173232408449017 }, { -0.823244234718400, -0.920085196137775 },
                    { -1.186534432549093, 0.446318558864201 }, { -0.207895139404806, 0.530350819368002 }, });

    /** Rank-2 factor rows for Y IDs 1..5, keyed by ID string. */
    public static final Map<String, float[]> Y = buildMatrix(1,
            new double[][] { { -0.720323513289685, 0.456546350776373 }, { -0.776018558846806, -0.349118056105777 },
                    { -0.538419102792183, 0.719706471415318 }, { -1.038195732260794, -0.221463305994448 },
                    { -0.317872218971108, -0.678009656770822 }, });

    /**
     * Produces the message for the given sequence ID. Only {@code id % 10} matters:
     * <ul>
     *   <li>0: key "MODEL", value a serialized skeleton PMML with "features", "implicit",
     *       "XIDs" and "YIDs" extensions;</li>
     *   <li>1..5: key "UP", value a JSON array {@code ["Y", id, Y[id], At[id]]};</li>
     *   <li>6..9: key "UP", value a JSON array {@code ["X", id, X[id], A[id]]}.</li>
     * </ul>
     * Negative IDs are not handled specially; callers presumably pass non-negative values.
     *
     * @param id sequence number of the datum to generate
     * @param random source of randomness; not used by this implementation
     * @return pair of message key and message body
     * @throws IOException if the message cannot be serialized
     */
    @Override
    public Pair<String, String> generate(int id, RandomGenerator random) throws IOException {
        if (id % 10 == 0) {
            PMML pmml = PMMLUtils.buildSkeletonPMML();
            PMMLUtils.addExtension(pmml, "features", "2");
            PMMLUtils.addExtension(pmml, "implicit", "true");
            PMMLUtils.addExtensionContent(pmml, "XIDs", X.keySet());
            PMMLUtils.addExtensionContent(pmml, "YIDs", Y.keySet());
            return new Pair<>("MODEL", PMMLUtils.toString(pmml));
        } else {
            int xOrYID = id % 10;
            String xOrYIDString = Integer.toString(xOrYID);
            String message;
            // IDs 6..9 are rows of X; IDs 1..5 are rows of Y
            boolean isX = xOrYID >= 6;
            if (isX) {
                message = MAPPER.writeValueAsString(
                        Arrays.asList("X", xOrYIDString, X.get(xOrYIDString), A.get(xOrYIDString)));
            } else {
                message = MAPPER.writeValueAsString(
                        Arrays.asList("Y", xOrYIDString, Y.get(xOrYIDString), At.get(xOrYIDString)));
            }
            return new Pair<>("UP", message);
        }
    }

    /**
     * Builds an unmodifiable map from consecutive integer IDs, rendered as strings and
     * starting at {@code startIndex}, to the given rows converted to {@code float[]}.
     *
     * @param startIndex ID assigned to the first row
     * @param rows matrix rows, in ID order
     * @return unmodifiable map of ID string to row vector
     */
    static Map<String, float[]> buildMatrix(int startIndex, double[]... rows) {
        Map<String, float[]> matrix = new HashMap<>(rows.length);
        int index = startIndex;
        for (double[] row : rows) {
            matrix.put(Integer.toString(index), VectorMath.toFloats(row));
            index++;
        }
        return Collections.unmodifiableMap(matrix);
    }

    /** Lists, per row of the documented matrix A, the columns holding a 1. */
    private static Map<String, Collection<String>> buildA() {
        Map<String, Collection<String>> a = new HashMap<>();
        a.put("6", ids("1", "4"));
        a.put("7", ids("2", "4", "5"));
        a.put("8", ids("1", "2", "3", "4"));
        // Last row of A is [ 0 0 1 0 0 ]: its only non-zero entry is column 3
        a.put("9", ids("3"));
        return Collections.unmodifiableMap(a);
    }

    /** Lists, per column of the documented matrix A, the rows holding a 1. */
    private static Map<String, Collection<String>> buildAt() {
        Map<String, Collection<String>> at = new HashMap<>();
        at.put("1", ids("6", "8"));
        at.put("2", ids("7", "8"));
        at.put("3", ids("8", "9"));
        at.put("4", ids("6", "7", "8"));
        at.put("5", ids("7"));
        return Collections.unmodifiableMap(at);
    }

    /** Wraps the given IDs in an unmodifiable list so the shared constants cannot be altered. */
    private static Collection<String> ids(String... ids) {
        return Collections.unmodifiableList(Arrays.asList(ids));
    }

}