org.diqube.remote.cluster.RIntermediateAggregationResultUtil.java Source code

Java tutorial

Introduction

Here is the source code for org.diqube.remote.cluster.RIntermediateAggregationResultUtil.java

Source

/**
 * diqube: Distributed Query Base.
 *
 * Copyright (C) 2015 Bastian Gloeckle
 *
 * This file is part of diqube.
 *
 * diqube is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.diqube.remote.cluster;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.diqube.data.column.ColumnType;
import org.diqube.function.IntermediaryResult;
import org.diqube.function.aggregate.result.IntermediaryResultValueIterator;
import org.diqube.function.aggregate.result.serialization.IntermediateResultSerialization;
import org.diqube.function.aggregate.result.serialization.IntermediateResultSerializationResolver;
import org.diqube.remote.cluster.thrift.RColumnType;
import org.diqube.remote.cluster.thrift.RIntermediateAggregationResult;
import org.diqube.remote.cluster.thrift.RIntermediateAggregationResultValue;
import org.diqube.thrift.base.thrift.RValue;
import org.diqube.thrift.base.util.RValueUtil;
import org.diqube.util.SafeObjectInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.ImmutableSet;
import com.google.common.reflect.ClassPath;
import com.google.common.reflect.ClassPath.ClassInfo;

/**
 * Util for {@link RIntermediateAggregationResult}.
 * 
 * Serialization/deserialization adheres to {@link IntermediateResultSerialization}.
 *
 * <p>
 * Values that {@link RValueUtil} can represent as a {@link RValue} are transported that way. All other values are
 * transported using Java serialization; only classes whose names were whitelisted by an
 * {@link IntermediateResultSerializationResolver} found on the classpath (below package {@value #ROOT_PKG}) are
 * accepted for that.
 *
 * <p>
 * The whitelist is computed lazily on first use and is published to other threads only after it has been fully
 * populated.
 *
 * @author Bastian Gloeckle
 */
public class RIntermediateAggregationResultUtil {
    /** Root package below which classes annotated with {@link IntermediateResultSerialization} are searched. */
    private static final String ROOT_PKG = "org.diqube";

    private static final Logger logger = LoggerFactory.getLogger(RIntermediateAggregationResultUtil.class);

    /**
     * Names of the classes that may be de-/serialized using Java serialization. <code>null</code> until
     * {@link #initialize()} completed; afterwards an immutable, fully populated set.
     */
    private static volatile Set<String> whitelistedSerializableClassNames = null;

    /**
     * Deserialize a {@link RIntermediateAggregationResult} to a {@link IntermediaryResult}.
     * 
     * @param input
     *          The thrift object to convert.
     * @return A new {@link IntermediaryResult} holding the output column name, the (optional) input column type and
     *         all values of the input, in the order they were contained in the input.
     * @throws IllegalArgumentException
     *           if data cannot be deserialized.
     */
    public static IntermediaryResult buildIntermediateAggregationResult(RIntermediateAggregationResult input)
            throws IllegalArgumentException {
        Set<String> whitelist = getWhitelist();

        ColumnType type = null;
        if (input.isSetInputColumnType()) {
            switch (input.getInputColumnType()) {
            case LONG:
                type = ColumnType.LONG;
                break;
            case DOUBLE:
                type = ColumnType.DOUBLE;
                break;
            default:
                type = ColumnType.STRING;
                break;
            }
        }

        IntermediaryResult res = new IntermediaryResult(input.getOutputColName(), type);

        for (RIntermediateAggregationResultValue val : input.getValues()) {
            if (val.isSetValue()) {
                res.pushValue(RValueUtil.createValue(val.getValue()));
            } else {
                byte[] serialized = val.getSerialized();
                if (serialized == null)
                    // Neither a plain value nor serialized bytes: report this as documented instead of failing with
                    // a NullPointerException inside ByteArrayInputStream.
                    throw new IllegalArgumentException(
                            "Could not deserialize intermediate result: neither value nor serialized data is set");

                try (ByteArrayInputStream bais = new ByteArrayInputStream(serialized)) {
                    // The stream is handed the whitelist of class names; presumably it refuses to load any class
                    // that is not contained in it (see SafeObjectInputStream).
                    try (ObjectInputStream ois = new SafeObjectInputStream(bais, whitelist)) {
                        res.pushValue(ois.readObject());
                    }
                } catch (IOException | ClassNotFoundException e) {
                    logger.error("Could not deserialize intermediate result", e);
                    throw new IllegalArgumentException("Could not deserialize intermediate result", e);
                }
            }
        }

        return res;
    }

    /**
     * Serialize a {@link IntermediaryResult}.
     * 
     * @param input
     *          The result to convert.
     * @return A new {@link RIntermediateAggregationResult} containing the output column name, the input column type (if
     *         available) and all values of the input's value iterator in iteration order.
     * @throws IllegalArgumentException
     *           If cannot be serialized
     */
    public static RIntermediateAggregationResult buildRIntermediateAggregationResult(IntermediaryResult input)
            throws IllegalArgumentException {
        Set<String> whitelist = getWhitelist();

        RIntermediateAggregationResult res = new RIntermediateAggregationResult();
        res.setOutputColName(input.getOutputColName());
        if (input.getInputColumnType() != null) {
            switch (input.getInputColumnType()) {
            case STRING:
                res.setInputColumnType(RColumnType.STRING);
                break;
            case LONG:
                res.setInputColumnType(RColumnType.LONG);
                break;
            case DOUBLE:
                res.setInputColumnType(RColumnType.DOUBLE);
                break;
            }
        }

        List<RIntermediateAggregationResultValue> values = new ArrayList<>();
        IntermediaryResultValueIterator it = input.createValueIterator();
        while (it.hasNext()) {
            Object valueObject = it.next();

            RIntermediateAggregationResultValue resValue = new RIntermediateAggregationResultValue();

            RValue rvalue = RValueUtil.createRValue(valueObject);
            if (rvalue != null) {
                resValue.setValue(rvalue);
            } else {
                if (valueObject == null)
                    // Nothing to inspect or serialize; fail with the documented exception type.
                    throw new IllegalArgumentException("Could not serialize intermediary result: value is null");

                if (!whitelist.contains(valueObject.getClass().getName()))
                    // only a shallow check, but better than no check at all.
                    throw new IllegalArgumentException(
                            "Class " + valueObject.getClass().getName() + " is not whitelisted.");

                try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
                    // Close the ObjectOutputStream first so everything is flushed into baos before reading the bytes.
                    try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
                        oos.writeObject(valueObject);
                    }

                    resValue.setSerialized(baos.toByteArray());
                } catch (IOException e) {
                    logger.error("Could not serialize intermediary result", e);
                    throw new IllegalArgumentException("Could not serialize intermediary result", e);
                }
            }

            values.add(resValue);
        }
        res.setValues(values);

        return res;
    }

    /**
     * Returns the fully initialized whitelist, computing it on first use.
     */
    private static Set<String> getWhitelist() {
        Set<String> whitelist = whitelistedSerializableClassNames;
        if (whitelist == null) {
            initialize();
            whitelist = whitelistedSerializableClassNames;
        }
        return whitelist;
    }

    /**
     * Scans the classpath below {@value #ROOT_PKG} for classes annotated with {@link IntermediateResultSerialization},
     * instantiates those that implement {@link IntermediateResultSerializationResolver} and collects the names of all
     * classes they report.
     *
     * The result is assigned to the volatile field only after the scan has finished, so threads that read the field
     * without holding the lock never observe a partially filled set.
     *
     * @throws RuntimeException
     *           if the classpath cannot be scanned.
     */
    private static synchronized void initialize() {
        if (whitelistedSerializableClassNames != null)
            return;

        ClassPath cp;
        try {
            cp = ClassPath.from(RIntermediateAggregationResultUtil.class.getClassLoader());
        } catch (IOException e) {
            throw new RuntimeException("Could not initialize classpath scanning!", e);
        }
        ImmutableSet<ClassInfo> classInfos = cp.getTopLevelClassesRecursive(ROOT_PKG);

        // Collect into a local set; do not touch the shared field until the set is complete.
        Set<String> collected = new HashSet<>();

        for (ClassInfo classInfo : classInfos) {
            Class<?> clazz = classInfo.load();
            if (clazz.getAnnotation(IntermediateResultSerialization.class) != null) {
                if (!IntermediateResultSerializationResolver.class.isAssignableFrom(clazz)) {
                    logger.warn("Class {} has {} annotation, but does not implement {}. Ignoring.", clazz.getName(),
                            IntermediateResultSerialization.class.getSimpleName(),
                            IntermediateResultSerializationResolver.class.getName());
                    continue;
                }

                try {
                    // getDeclaredConstructor().newInstance() replaces the deprecated Class#newInstance(), which
                    // would propagate checked exceptions of the constructor without declaring them.
                    IntermediateResultSerializationResolver resolver = (IntermediateResultSerializationResolver) clazz
                            .getDeclaredConstructor().newInstance();

                    resolver.resolve(cls -> {
                        collected.add(cls.getName());
                        logger.debug(
                                "Whitelisted class {} for being de-/serialized for intermediate aggregation results",
                                cls);
                    });
                } catch (ReflectiveOperationException e) {
                    logger.warn("Could not instantiate {}. Ignoring.", clazz.getName(), e);
                }
            }
        }

        // Single volatile write of an immutable snapshot: safe publication of the complete whitelist.
        whitelistedSerializableClassNames = ImmutableSet.copyOf(collected);
    }
}