io.crate.analyze.where.EqualityExtractor.java Source code

Java tutorial

Introduction

Here is the source code for io.crate.analyze.where.EqualityExtractor.java

Source

/*
 * Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
 * license agreements.  See the NOTICE file distributed with this work for
 * additional information regarding copyright ownership.  Crate licenses
 * this file to you under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.  You may
 * obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * However, if you have executed another commercial license agreement
 * with Crate these terms will supersede the license and you may use the
 * software solely pursuant to the terms of the relevant commercial agreement.
 */

package io.crate.analyze.where;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import io.crate.analyze.EvaluatingNormalizer;
import io.crate.analyze.symbol.*;
import io.crate.metadata.*;
import io.crate.operation.operator.EqOperator;
import io.crate.operation.operator.any.AnyEqOperator;
import io.crate.types.CollectionType;
import io.crate.types.DataType;
import io.crate.types.DataTypes;
import org.elasticsearch.common.io.stream.StreamOutput;

import javax.annotation.Nullable;
import java.io.IOException;
import java.util.*;

class EqualityExtractor {

    private static final Function NULL_MARKER = new Function(
            new FunctionInfo(new FunctionIdent("null_marker", ImmutableList.<DataType>of()), DataTypes.UNDEFINED),
            ImmutableList.<Symbol>of());
    private static final EqProxy NULL_MARKER_PROXY = new EqProxy(NULL_MARKER);

    private EvaluatingNormalizer normalizer;

    EqualityExtractor(EvaluatingNormalizer normalizer) {
        this.normalizer = normalizer;
    }

    List<List<Symbol>> extractParentMatches(List<ColumnIdent> columns, Symbol symbol,
            @Nullable TransactionContext transactionContext) {
        return extractMatches(columns, symbol, false, transactionContext);
    }

    List<List<Symbol>> extractExactMatches(List<ColumnIdent> columns, Symbol symbol,
            @Nullable TransactionContext transactionContext) {
        return extractMatches(columns, symbol, true, transactionContext);
    }

    private List<List<Symbol>> extractMatches(Collection<ColumnIdent> columns, Symbol symbol, boolean exact,
            @Nullable TransactionContext transactionContext) {
        EqualityExtractor.ProxyInjectingVisitor.Context context = new EqualityExtractor.ProxyInjectingVisitor.Context(
                columns, exact);
        Symbol proxiedTree = ProxyInjectingVisitor.INSTANCE.process(symbol, context);

        // bail out if we have any unknown part in the tree
        if (context.exact && context.seenUnknown) {
            return null;
        }

        List<Set<EqProxy>> comparisons = context.comparisonSet();
        Set<List<EqProxy>> cp = Sets.cartesianProduct(comparisons);

        List<List<Symbol>> result = new ArrayList<>();
        for (List<EqProxy> proxies : cp) {
            boolean anyNull = false;
            for (EqProxy proxy : proxies) {
                if (proxy != NULL_MARKER_PROXY) {
                    proxy.setTrue();
                } else {
                    anyNull = true;
                }
            }
            Symbol normalized = normalizer.normalize(proxiedTree, transactionContext);
            if (normalized == Literal.BOOLEAN_TRUE) {
                if (anyNull) {
                    return null;
                }
                if (!proxies.isEmpty()) {
                    List<Symbol> row = new ArrayList<>(proxies.size());
                    for (EqProxy proxy : proxies) {
                        proxy.reset();
                        row.add(proxy.origin.arguments().get(1));
                    }
                    result.add(row);
                }
            } else {
                for (EqProxy proxy : proxies) {
                    proxy.reset();
                }
            }
        }
        return result.isEmpty() ? null : result;

    }

    /**
     * Wraps any_= functions and exhibits the same logical semantics like
     * OR-chained comparisons.
     * <p>
     * creates EqProxies for every single any array element
     * and shares pre existing proxies - so true value can be set on
     * this "parent" if a shared comparison is "true"
     */
    private static class AnyEqProxy extends EqProxy implements Iterable<EqProxy> {

        private Map<Function, EqProxy> proxies;
        @Nullable
        private ChildEqProxy delegate = null;

        private AnyEqProxy(Function compared, Map<Function, EqProxy> existingProxies) {
            super(compared);
            initProxies(existingProxies);
        }

        private void initProxies(Map<Function, EqProxy> existingProxies) {
            Symbol left = origin.arguments().get(0);
            DataType leftType = origin.info().ident().argumentTypes().get(0);
            DataType rightType = ((CollectionType) origin.info().ident().argumentTypes().get(1)).innerType();
            FunctionInfo eqInfo = new FunctionInfo(
                    new FunctionIdent(EqOperator.NAME, ImmutableList.of(leftType, rightType)), DataTypes.BOOLEAN);
            Literal arrayLiteral = (Literal) origin.arguments().get(1);
            proxies = new HashMap<>();
            for (Literal arrayElem : Literal.explodeCollection(arrayLiteral)) {
                Function f = new Function(eqInfo, Arrays.asList(left, arrayElem));
                EqProxy existingProxy = existingProxies.get(f);
                if (existingProxy == null) {
                    existingProxy = new ChildEqProxy(f, this);
                } else if (existingProxy instanceof ChildEqProxy) {
                    ((ChildEqProxy) existingProxy).addParent(this);
                }
                proxies.put(f, existingProxy);
            }
        }

        @Override
        public Iterator<EqProxy> iterator() {
            return proxies.values().iterator();
        }

        @Override
        public <C, R> R accept(SymbolVisitor<C, R> visitor, C context) {
            if (delegate != null) {
                return delegate.accept(visitor, context);
            }
            return super.accept(visitor, context);
        }

        private void setDelegate(@Nullable ChildEqProxy childEqProxy) {
            delegate = childEqProxy;
        }

        private void cleanDelegate() {
            delegate = null;
        }

        private static class ChildEqProxy extends EqProxy {

            private List<AnyEqProxy> parentProxies = new ArrayList<>();

            private ChildEqProxy(Function origin, AnyEqProxy parent) {
                super(origin);
                this.addParent(parent);
            }

            private void addParent(AnyEqProxy parentProxy) {
                parentProxies.add(parentProxy);
            }

            @Override
            public void setTrue() {
                super.setTrue();
                for (AnyEqProxy parent : parentProxies) {
                    parent.setTrue();
                    parent.setDelegate(this);
                }
            }

            @Override
            public void reset() {
                super.reset();
                for (AnyEqProxy parent : parentProxies) {
                    parent.reset();
                    parent.cleanDelegate();
                }
            }
        }
    }

    static class EqProxy extends Symbol {

        protected Symbol current;
        protected final Function origin;

        private Function origin() {
            return origin;
        }

        EqProxy(Function origin) {
            this.origin = origin;
            this.current = origin;
        }

        public void reset() {
            this.current = origin;
        }

        public void setTrue() {
            this.current = Literal.BOOLEAN_TRUE;
        }

        @Override
        public SymbolType symbolType() {
            return current.symbolType();
        }

        @Override
        public <C, R> R accept(SymbolVisitor<C, R> visitor, C context) {
            return current.accept(visitor, context);
        }

        @Override
        public DataType valueType() {
            return current.valueType();
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            throw new UnsupportedOperationException();
        }

        public String forDisplay() {
            if (this == NULL_MARKER_PROXY) {
                return "NULL";
            }
            String s = "(" + ((Reference) origin.arguments().get(0)).ident().columnIdent().fqn() + "="
                    + ((Literal) origin.arguments().get(1)).value() + ")";
            if (current != origin) {
                s += " TRUE";
            }
            return s;
        }

        @Override
        public String toString() {
            return "EqProxy{" + forDisplay() + "}";
        }
    }

    static class ProxyInjectingVisitor extends SymbolVisitor<ProxyInjectingVisitor.Context, Symbol> {

        public static final ProxyInjectingVisitor INSTANCE = new ProxyInjectingVisitor();

        private ProxyInjectingVisitor() {
        }

        static class Comparison {
            final HashMap<Function, EqProxy> proxies = new HashMap<>();

            public Comparison() {
                proxies.put(NULL_MARKER, NULL_MARKER_PROXY);
            }

            public EqProxy add(Function compared) {
                if (compared.info().ident().name().equals(AnyEqOperator.NAME)) {
                    AnyEqProxy anyEqProxy = new AnyEqProxy(compared, proxies);
                    for (EqProxy proxiedProxy : anyEqProxy) {
                        if (!proxies.containsKey(proxiedProxy.origin())) {
                            proxies.put(proxiedProxy.origin(), proxiedProxy);
                        }
                    }
                    return anyEqProxy;
                }
                EqProxy proxy = proxies.get(compared);
                if (proxy == null) {
                    proxy = new EqProxy(compared);
                    proxies.put(compared, proxy);
                }
                return proxy;
            }

        }

        static class Context {

            private LinkedHashMap<ColumnIdent, Comparison> comparisons;
            private boolean proxyBelow;
            private boolean seenUnknown = false;
            private final boolean exact;

            private Context(Collection<ColumnIdent> references, boolean exact) {
                this.exact = exact;
                comparisons = new LinkedHashMap<>(references.size());
                for (ColumnIdent reference : references) {
                    comparisons.put(reference, new Comparison());
                }
            }

            private List<Set<EqProxy>> comparisonSet() {
                List<Set<EqProxy>> comps = new ArrayList<>(comparisons.size());
                for (Comparison comparison : comparisons.values()) {
                    // TODO: probably create a view instead of a seperate set
                    comps.add(new HashSet<>(comparison.proxies.values()));
                }
                return comps;
            }
        }

        @Override
        protected Symbol visitSymbol(Symbol symbol, Context context) {
            return symbol;
        }

        @Override
        public Symbol visitMatchPredicate(MatchPredicate matchPredicate, Context context) {
            context.seenUnknown = true;
            return Literal.BOOLEAN_TRUE;
        }

        @Override
        public Symbol visitReference(Reference symbol, Context context) {
            if (!context.comparisons.containsKey(symbol.ident().columnIdent())) {
                context.seenUnknown = true;
            }
            return super.visitReference(symbol, context);
        }

        public Symbol visitFunction(Function function, Context context) {

            String functionName = function.info().ident().name();
            if (functionName.equals(EqOperator.NAME)) {
                if (function.arguments().get(0) instanceof Reference) {
                    Comparison comparison = context.comparisons
                            .get(((Reference) function.arguments().get(0)).ident().columnIdent());
                    if (comparison != null) {
                        context.proxyBelow = true;
                        return comparison.add(function);
                    }
                }
            } else if (functionName.equals(AnyEqOperator.NAME)
                    && function.arguments().get(1).symbolType().isValueSymbol()) {
                // ref = any ([1,2,3])
                if (function.arguments().get(0) instanceof Reference) {
                    Reference reference = (Reference) function.arguments().get(0);
                    Comparison comparison = context.comparisons.get(reference.ident().columnIdent());
                    if (comparison != null) {
                        context.proxyBelow = true;
                        return comparison.add(function);
                    }
                }
            }
            boolean proxyBelowPre = context.proxyBelow;
            boolean proxyBelowPost = proxyBelowPre;
            ArrayList<Symbol> args = new ArrayList<>(function.arguments().size());
            for (Symbol arg : function.arguments()) {
                context.proxyBelow = proxyBelowPre;
                args.add(process(arg, context));
                proxyBelowPost = context.proxyBelow || proxyBelowPost;
            }
            context.proxyBelow = proxyBelowPost;
            if (!context.proxyBelow && function.valueType().equals(DataTypes.BOOLEAN)) {
                return Literal.BOOLEAN_TRUE;
            }
            return new Function(function.info(), args);
        }
    }
}