ml.shifu.core.util.ValidationUtils.java Source code

Java tutorial

Introduction

Here is the source code for ml.shifu.core.util.ValidationUtils.java

Source

/**
 * Copyright [2012-2014] eBay Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.core.util;

import ml.shifu.core.container.fieldMeta.Field;
import ml.shifu.core.container.fieldMeta.FieldBasics;
import ml.shifu.core.container.fieldMeta.FieldMeta;
import ml.shifu.core.exception.SizeMismatchException;
import ml.shifu.core.util.Constants.VAR_IMPORTANCE;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/**
 * {@link ValidationUtils} is used for almost all kinds of validation utility functions in this framework.
 */
public final class ValidationUtils {

    /** Class-wide logger used by the validation helpers to report failures before throwing. */
    private static final Logger LOG = LoggerFactory.getLogger(ValidationUtils.class);

    /**
     * Private constructor: this class only exposes static helpers and is not meant to be
     * instantiated with {@code new}.
     */
    private ValidationUtils() {
    }

    /**
     * Checks that exactly one field in {@code fieldMeta} carries the target field name.
     *
     * @param fieldMeta       metadata whose field list is searched
     * @param targetFieldName name the target field is expected to have
     * @throws RuntimeException if {@code targetFieldName} is null/empty, if no field has that
     *                          name, or if more than one field has that name
     */
    public static void validateTargetFieldExistsAndUnique(FieldMeta fieldMeta, String targetFieldName) {

        if (StringUtils.isEmpty(targetFieldName)) {
            throw new RuntimeException("Empty targetFieldName");
        }

        int matches = 0;
        List<Field> fields = fieldMeta.getFields();
        // A missing field list is treated like an empty one: the target simply does not exist.
        if (fields != null) {
            for (Field field : fields) {
                // targetFieldName is non-null here, so calling equals on it lets a field
                // without a name count as "no match" instead of raising a NullPointerException.
                if (targetFieldName.equals(field.getName())) {
                    matches++;
                }
            }
        }

        if (matches == 0) {
            throw new RuntimeException("Target field does not exist: " + targetFieldName);
        }

        if (matches > 1) {
            throw new RuntimeException(
                    "Duplicated target field: " + targetFieldName + "(" + matches + " duplicated fields)");
        }
    }

    /**
     * Verifies that two named arrays hold the same number of elements.
     *
     * @param nameA label reported for the first array
     * @param a     first array
     * @param nameB label reported for the second array
     * @param b     second array
     * @throws SizeMismatchException if the two lengths differ
     */
    public static void validateSizeMatch(String nameA, Double[] a, String nameB, Double[] b) {
        final int sizeA = a.length;
        final int sizeB = b.length;
        if (sizeA == sizeB) {
            return;
        }
        throw new SizeMismatchException(nameA, sizeA, nameB, sizeB);
    }

    /**
     * Validates that the given path value is non-empty and points at an existing file or
     * directory.
     *
     * @param param name of the configuration parameter the path came from (used in messages)
     * @param path  path value to check
     * @throws RuntimeException if {@code path} is null or empty, or if nothing exists at
     *                          that location
     */
    public static void validatePath(String param, String path) {

        // A null path is reported as an invalid parameter value rather than escaping as a
        // NullPointerException from path.length().
        if (path == null || path.length() == 0) {
            throw new RuntimeException("Invalid param [" + param + "], value [" + path + "]");
        }

        File file = new File(path);
        if (!file.exists()) {
            throw new RuntimeException("Required file [" + path + "] does not exist");
        }
    }

    /**
     * Validates a variable importance setting: the metric must name a known
     * {@link VAR_IMPORTANCE} value and {@code maxVariables} must parse as an integer, which
     * additionally has to be positive unless the metric is {@code NONE}.
     *
     * @param param        name of the configuration parameter being checked (used in messages)
     * @param metric       name of the variable importance metric
     * @param maxVariables textual integer value that accompanies the metric
     * @throws RuntimeException if the metric is null/empty or unknown, or if
     *                          {@code maxVariables} is not an integer or not positive for a
     *                          metric other than {@code NONE}
     */
    public static void validateVarImportanceMetric(String param, String metric, String maxVariables) {

        String errMessage = null;
        if (metric == null || metric.length() == 0) {
            errMessage = "Invalid param [" + param + "], value null/empty";
            LOG.error(errMessage);
            throw new RuntimeException(errMessage);
        }

        // Ensure the metric matches one of the known values. valueOf is expected to throw
        // IllegalArgumentException for an unknown name (assuming VAR_IMPORTANCE is an enum),
        // so the lookup is wrapped to report the configuration error with a proper message.
        VAR_IMPORTANCE importance;
        try {
            importance = VAR_IMPORTANCE.valueOf(metric);
        } catch (IllegalArgumentException e) {
            errMessage = "Invalid value for param [" + param + "]. Specified [" + metric + "] is invalid";
            LOG.error(errMessage);
            throw new RuntimeException(errMessage, e);
        }

        // validate value
        int valueInt;
        try {
            valueInt = Integer.parseInt(maxVariables);
        } catch (NumberFormatException e) {
            // not an integer (this also covers a null maxVariables)
            errMessage = "Invalid value for param [" + param + "]. Specified [" + maxVariables
                    + "] is invalid for [" + metric + "]";
            LOG.error(errMessage);
            throw new RuntimeException(errMessage, e);
        }

        if (!importance.equals(VAR_IMPORTANCE.NONE) && valueInt <= 0) {
            // not a positive integer
            errMessage = "Invalid value for param [" + param + "]. Specified [" + maxVariables
                    + "] is invalid for [" + metric + "]";
            LOG.error(errMessage);
            throw new RuntimeException(errMessage);
        }

    }

    /**
     * Rejects a null or empty string.
     *
     * @param str value to check
     * @throws RuntimeException if {@code str} is null or has length zero
     */
    public static void validateString(String str) {
        final boolean hasContent = str != null && !str.isEmpty();
        if (hasContent) {
            return;
        }

        final String errMessage = "Invalid string " + str;
        LOG.error(errMessage);
        throw new RuntimeException(errMessage);
    }

    /**
     * Rejects a null or empty string that was supplied for a named parameter.
     *
     * @param param name of the parameter the value belongs to (used in the message)
     * @param str   value to check
     * @throws RuntimeException if {@code str} is null or has length zero
     */
    public static void validateString(String param, String str) {
        final boolean hasContent = str != null && !str.isEmpty();
        if (hasContent) {
            return;
        }

        final String errMessage = "Invalid string [" + str + "] for param [" + param + "]. Check configuration.";
        LOG.error(errMessage);
        throw new RuntimeException(errMessage);
    }

    /**
     * Rejects a null object that was supplied for a named parameter.
     *
     * @param param name of the parameter the object belongs to (used in the message)
     * @param obj   object to check
     * @throws RuntimeException if {@code obj} is null
     */
    public static void validateNullObject(String param, Object obj) {
        if (obj != null) {
            return;
        }

        final String errMessage = "Invalid/null object for param [" + param + "]";
        LOG.error(errMessage);
        throw new RuntimeException(errMessage);
    }

    /**
     * Checks that a field's opType is present and is one of CONTINUOUS, CATEGORICAL or ORDINAL.
     *
     * @param field  name of the field the opType belongs to (used in messages)
     * @param opType opType to check
     * @throws RuntimeException if {@code opType} is null or is not one of the accepted values
     */
    public static void validateOpType(String field, FieldBasics.OpType opType) {

        if (opType == null) {
            final String nullMessage = "Invalid/null object for opType for field " + field + " in FieldMeta";
            LOG.error(nullMessage);
            throw new RuntimeException(nullMessage);
        }

        final boolean accepted = opType.equals(FieldBasics.OpType.CONTINUOUS)
                || opType.equals(FieldBasics.OpType.CATEGORICAL)
                || opType.equals(FieldBasics.OpType.ORDINAL);
        if (accepted) {
            return;
        }

        final String invalidMessage = "Invalid opType for field " + field + " in FieldMeta";
        LOG.error(invalidMessage);
        throw new RuntimeException(invalidMessage);
    }

    /**
     * Validates the field list of a FieldMeta: it must contain at least
     * {@code Constants.MIN_NUM_FIELDS} fields and no two fields may share a name.
     *
     * @param fieldMeta metadata whose fields are checked
     * @throws RuntimeException if the field list is null or too short, or if a field name
     *                          occurs more than once
     */
    public static void validateFields(FieldMeta fieldMeta) {

        // validate num fields
        // num_fields < MIN_NUM_FIELDS
        String errMessage = null;
        List<Field> fields = fieldMeta.getFields();
        if (fields == null || fields.size() < Constants.MIN_NUM_FIELDS) {
            errMessage = "Invalid number of fields in FieldMeta (< " + " " + Constants.MIN_NUM_FIELDS
                    + " ). Check header file/configuration.";
            LOG.error(errMessage);
            throw new RuntimeException(errMessage);
        }

        // detect duplicate fields
        Set<String> seenNames = new HashSet<String>();
        for (Field field : fields) {
            String name = field.getFieldBasics().getName();
            // Set.add returns false when the name is already present, i.e. a duplicate.
            if (!seenNames.add(name)) {
                errMessage = "Duplicate field name [" + name + "] found. Check header file/configuration.";
                LOG.error(errMessage);
                throw new RuntimeException(errMessage);
            }
        }

        // TODO any character not allowed in field name?

    }

    /**
     * Looks up a field by name in the given list.
     *
     * @param fieldList fields to search; may be null or empty
     * @param field     field whose name is looked up; may be null
     * @return the first element of {@code fieldList} whose name equals the name of
     *         {@code field}, or {@code null} if there is no such element (including when
     *         the list is null/empty, or {@code field} or its name is null)
     */
    public static Field fieldExists(List<Field> fieldList, Field field) {

        if (fieldList == null || fieldList.isEmpty() || field == null) {
            return null;
        }

        // The wanted name does not change while scanning, so resolve it once up front.
        String wantedName = field.getFieldBasics().getName();
        if (wantedName == null) {
            return null;
        }

        for (Field candidate : fieldList) {
            // equals is called on the non-null wanted name so a candidate without a name
            // is skipped instead of raising a NullPointerException.
            if (wantedName.equals(candidate.getFieldBasics().getName())) {
                return candidate;
            }
        }
        return null;

    }

    /**
     * Checks that a field with the given name is present in the FieldMeta.
     *
     * @param param     name of the configuration parameter the field name came from
     *                  (used in the message)
     * @param fieldName name of the field that must be present
     * @param fieldMeta metadata whose fields are searched
     * @throws RuntimeException if the field list is null or holds fewer than
     *                          {@code Constants.MIN_NUM_FIELDS} entries, or if no field
     *                          has the given name
     */
    public static void validateField(String param, String fieldName, FieldMeta fieldMeta) {

        final List<Field> fields = fieldMeta.getFields();

        // the field list must exist and hold at least MIN_NUM_FIELDS entries
        final boolean tooFewFields = fields == null || fields.size() < Constants.MIN_NUM_FIELDS;
        if (tooFewFields) {
            final String countMessage = "Invalid number of fields in FieldMeta (< " + " " + Constants.MIN_NUM_FIELDS
                    + " ). Check header file.";
            LOG.error(countMessage);
            throw new RuntimeException(countMessage);
        }

        // stop at the first field carrying the requested name
        for (Field candidate : fields) {
            if (candidate.getFieldBasics().getName().equals(fieldName)) {
                return;
            }
        }

        final String missingMessage = "Field name [" + fieldName
                + "] not found in FieldMeta. Check your configuration param " + param + ".";
        LOG.error(missingMessage);
        throw new RuntimeException(missingMessage);

    }

    /**
     * Validates that the field names in the three given lists are unique across all lists
     * taken together and that every one of them exists in the FieldMeta.
     *
     * <p>All offending names are collected and logged before the method fails, so a single
     * run reports every duplicate (or every unknown name) rather than only the first one.
     * Duplicates are reported before unknown names.
     *
     * @param param       name of the configuration parameter the lists came from (used in messages)
     * @param listaFields first list of fields; may be null
     * @param listbFields second list of fields; may be null
     * @param listcFields third list of fields; may be null
     * @param fieldMeta   metadata that defines which field names exist
     * @throws RuntimeException if a field name occurs more than once across the lists, or if
     *                          a listed field name is not present in {@code fieldMeta}
     */
    public static void validateFieldUniqueExistence(String param, List<Field> listaFields, List<Field> listbFields,
            List<Field> listcFields, FieldMeta fieldMeta) {

        // populate field names to set for lookup; a missing field list is treated as empty,
        // which makes every listed field count as non-existent
        Set<String> knownNames = new HashSet<String>();
        List<Field> fieldMetaFields = fieldMeta.getFields();
        if (fieldMetaFields != null) {
            for (Field field : fieldMetaFields) {
                knownNames.add(field.getFieldBasics().getName());
            }
        }

        // seenNames is shared by all three lists, so a name used in two different lists is
        // flagged exactly like a name repeated within one list
        Set<String> seenNames = new HashSet<String>();
        Set<String> duplicateNames = new HashSet<String>();
        Set<String> unknownNames = new HashSet<String>();

        recordDuplicateAndUnknownNames(listaFields, knownNames, seenNames, duplicateNames, unknownNames);
        recordDuplicateAndUnknownNames(listbFields, knownNames, seenNames, duplicateNames, unknownNames);
        recordDuplicateAndUnknownNames(listcFields, knownNames, seenNames, duplicateNames, unknownNames);

        // check for duplicate set and abort if found
        if (!duplicateNames.isEmpty()) {
            String errMessage = "Duplicate field names found. Check " + param + " list.";
            LOG.error(errMessage);
            LOG.error("Duplicate fields:" + onePerLine(duplicateNames));
            throw new RuntimeException(errMessage);
        }

        // check for non-existent names and abort if found
        if (!unknownNames.isEmpty()) {
            String errMessage = "Found field names that do not exist in the header. Check " + param + " list.";
            LOG.error(errMessage);
            LOG.error("Non existent fields:" + onePerLine(unknownNames));
            throw new RuntimeException(errMessage);
        }

    }

    /**
     * Scans one list of fields, recording names already seen as duplicates and names missing
     * from {@code knownNames} as unknown; a null list is skipped.
     */
    private static void recordDuplicateAndUnknownNames(List<Field> fields, Set<String> knownNames,
            Set<String> seenNames, Set<String> duplicateNames, Set<String> unknownNames) {

        if (fields == null) {
            return;
        }
        for (Field field : fields) {
            String name = field.getFieldBasics().getName();
            // Set.add returns false when the name was already recorded by this or an earlier list.
            if (!seenNames.add(name)) {
                duplicateNames.add(name);
            }
            // check for fields that do not exist in fieldMeta
            if (!knownNames.contains(name)) {
                unknownNames.add(name);
            }
        }
    }

    /**
     * Renders the names as a block of text that starts with a newline and puts each name on
     * its own line.
     */
    private static String onePerLine(Set<String> names) {
        StringBuilder text = new StringBuilder("\n");
        for (String name : names) {
            text.append(name).append("\n");
        }
        return text.toString();
    }

    /**
     * Checks that exactly one target selector is given and that a field with that name is
     * present in the FieldMeta.
     *
     * @param fieldMeta       metadata whose fields are searched
     * @param targetSelectors list expected to hold exactly one target field name
     * @throws RuntimeException if {@code targetSelectors} is null or does not hold exactly
     *                          one entry, or if no field carries the target name
     */
    public static void validateTarget(FieldMeta fieldMeta, List<String> targetSelectors) {

        final boolean singleTarget = targetSelectors != null && targetSelectors.size() == 1;
        if (!singleTarget) {
            final String countMessage = "Invalid value(s) for target field. Check target param";
            LOG.error(countMessage);
            throw new RuntimeException(countMessage);
        }

        // look the single target name up among the FieldMeta fields
        final String targetName = targetSelectors.get(0);
        for (Field candidate : fieldMeta.getFields()) {
            if (candidate.getFieldBasics().getName().equals(targetName)) {
                return;
            }
        }

        final String missingMessage = "Target field name " + targetName
                + " not found in FieldMeta. Check target param";
        LOG.error(missingMessage);
        throw new RuntimeException(missingMessage);

    }

    /**
     * Checks that a selector list is present and that none of its entries is null or blank.
     *
     * @param param           name of the configuration parameter the selectors came from
     *                        (used in messages)
     * @param targetSelectors selectors to check
     * @throws RuntimeException if {@code targetSelectors} is null, or if any entry is null or
     *                          empty after trimming
     */
    public static void validateSelectors(String param, List<String> targetSelectors) {

        if (targetSelectors == null) {
            final String nullMessage = "Invalid value(s) for param " + param + ". Check configuration.";
            LOG.error(nullMessage);
            throw new RuntimeException(nullMessage);
        }

        // every selector must carry some non-whitespace text
        for (String selector : targetSelectors) {
            final boolean blank = selector == null || selector.trim().isEmpty();
            if (blank) {
                final String blankMessage = "Invalid value for param " + param + ". Check your configuration.";
                LOG.error(blankMessage);
                throw new RuntimeException(blankMessage);
            }
        }

    }

}