com.cloudera.oryx.kmeans.serving.generation.VectorFactory.java Source code

Java tutorial

Introduction

Here is the source code for com.cloudera.oryx.kmeans.serving.generation.VectorFactory.java

Source

/*
 * Copyright (c) 2013, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */

package com.cloudera.oryx.kmeans.serving.generation;

import com.cloudera.oryx.common.math.Vectors;
import com.google.common.base.Function;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import org.apache.commons.math3.linear.RealVector;
import org.dmg.pmml.Apply;
import org.dmg.pmml.ClusteringField;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.LinearNorm;
import org.dmg.pmml.LocalTransformations;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.NormDiscrete;

import java.util.List;
import java.util.Map;
import java.util.Set;

public final class VectorFactory {
    private final boolean sparse;
    private final int fieldCount;
    private final List<Set<Update>> updates;

    public static VectorFactory create(MiningSchema schema, LocalTransformations transforms,
            List<ClusteringField> fields) {
        Map<FieldName, DerivedField> derived = Maps
                .newHashMapWithExpectedSize(transforms.getDerivedFields().size());
        for (DerivedField df : transforms.getDerivedFields()) {
            derived.put(df.getName(), df);
        }

        Multimap<FieldName, NumericUpdate> numeric = HashMultimap.create();
        Map<FieldName, Map<String, Update>> categorical = Maps.newHashMap();
        for (int j = 0; j < fields.size(); j++) {
            ClusteringField cf = fields.get(j);
            FieldName fn = cf.getField();
            if (derived.containsKey(fn)) {
                DerivedField df = derived.get(fn);
                Expression e = df.getExpression();
                if (e instanceof NormDiscrete) {
                    NormDiscrete nd = (NormDiscrete) e;
                    Map<String, Update> m = categorical.get(nd.getField());
                    if (m == null) {
                        m = Maps.newHashMap();
                        categorical.put(nd.getField(), m);
                    }
                    m.put(nd.getValue(), new NumericUpdate(ONE, j, cf.getFieldWeight()));
                } else if (e instanceof Apply) {
                    Apply apply = (Apply) e;
                    if (!"ln".equals(apply.getFunction())) {
                        throw new UnsupportedOperationException(
                                "Unsupported function type: " + apply.getFunction());
                    }
                    FieldName f = ((FieldRef) apply.getExpressions().get(0)).getField();
                    numeric.put(f, new NumericUpdate(LOG_VALUE, j, cf.getFieldWeight()));
                } else if (e instanceof NormContinuous) {
                    NormContinuous nc = (NormContinuous) e;
                    FieldName f = nc.getField();
                    LinearNorm l1 = nc.getLinearNorms().get(0);
                    LinearNorm l2 = nc.getLinearNorms().get(1);
                    InterpolateFunction ifunc = new InterpolateFunction(l1.getOrig(), l1.getNorm(), l2.getOrig(),
                            l2.getNorm());
                    numeric.put(f, new NumericUpdate(ifunc, j, cf.getFieldWeight()));
                } else {
                    throw new UnsupportedOperationException("Unsupported expression type: " + e);
                }
            } else {
                numeric.put(fn, new NumericUpdate(VALUE, j, cf.getFieldWeight()));
            }
        }

        boolean sparse = 2 * schema.getMiningFields().size() <= fields.size();
        List<Set<Update>> updates = Lists.newArrayListWithExpectedSize(schema.getMiningFields().size());
        for (MiningField mf : schema.getMiningFields()) {
            FieldName fn = mf.getName();
            if (numeric.containsKey(fn)) {
                updates.add(ImmutableSet.<Update>copyOf(numeric.get(fn)));
            } else if (categorical.containsKey(fn)) {
                CategoricalUpdate u = new CategoricalUpdate(categorical.get(fn));
                updates.add(ImmutableSet.<Update>of(u));
            }
        }
        return new VectorFactory(sparse, fields.size(), updates);
    }

    private VectorFactory(boolean sparse, int fieldCount, List<Set<Update>> updates) {
        this.sparse = sparse;
        this.fieldCount = fieldCount;
        this.updates = updates;
    }

    public RealVector createVector(String[] tokens) {
        if (tokens.length != updates.size()) {
            return null;
        }
        RealVector v = sparse ? Vectors.sparse(fieldCount) : Vectors.dense(fieldCount);
        for (int i = 0; i < tokens.length; i++) {
            for (Update update : updates.get(i)) {
                update.update(v, tokens[i]);
            }
        }
        return v;
    }

    private interface Update {
        void update(RealVector v, String value);
    }

    private static final class CategoricalUpdate implements Update {
        private final Map<String, Update> updates;

        CategoricalUpdate(Map<String, Update> updates) {
            this.updates = updates;
        }

        @Override
        public void update(RealVector v, String value) {
            updates.get(value).update(v, value);
        }
    }

    private static final class NumericUpdate implements Update {
        private final ConvertFunction c;
        private final int offset;
        private final double scale;

        NumericUpdate(ConvertFunction c, int offset, double scale) {
            this.c = c;
            this.offset = offset;
            this.scale = scale;
        }

        @Override
        public void update(RealVector v, String value) {
            v.setEntry(offset, c.apply(value) * scale);
        }
    }

    private interface ConvertFunction extends Function<String, Double> {
    }

    private static final ConvertFunction ONE = new ConvertFunction() {
        @Override
        public Double apply(String input) {
            return 1.0;
        }
    };

    private static final ConvertFunction VALUE = new ConvertFunction() {
        @Override
        public Double apply(String input) {
            try {
                return Double.valueOf(input);
            } catch (NumberFormatException e) {
                return Double.NaN;
            }
        }
    };
    private static final ConvertFunction LOG_VALUE = new ConvertFunction() {
        @Override
        public Double apply(String input) {
            try {
                return Math.log(Double.valueOf(input));
            } catch (NumberFormatException e) {
                return Double.NaN;
            }
        }
    };

    private static final class InterpolateFunction implements ConvertFunction {
        private final double a1;
        private final double b1;
        private final double a2;
        private final double b2;

        private InterpolateFunction(double a1, double b1, double a2, double b2) {
            this.a1 = a1;
            this.b1 = b1;
            this.a2 = a2;
            this.b2 = b2;
        }

        @Override
        public Double apply(String input) {
            try {
                double x = Double.valueOf(input);
                return b1 + ((x - a1) / (a2 - a1)) * (b2 - b1);
            } catch (NumberFormatException e) {
                return Double.NaN;
            }
        }
    }
}