com.dangdang.ddframe.rdb.sharding.merger.groupby.GroupByResultSet.java Source code

Introduction

Here is the source code for com.dangdang.ddframe.rdb.sharding.merger.groupby.GroupByResultSet.java, the GROUP BY merging result set from dangdang's sharding JDBC framework.

Source

/**
 * Copyright 1999-2015 dangdang.com.
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * </p>
 */

package com.dangdang.ddframe.rdb.sharding.merger.groupby;

import java.math.BigDecimal;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import com.dangdang.ddframe.rdb.sharding.executor.ExecuteUnit;
import com.dangdang.ddframe.rdb.sharding.executor.ExecutorEngine;
import com.dangdang.ddframe.rdb.sharding.executor.MergeUnit;
import com.dangdang.ddframe.rdb.sharding.jdbc.AbstractShardingResultSet;
import com.dangdang.ddframe.rdb.sharding.merger.aggregation.AggregationUnit;
import com.dangdang.ddframe.rdb.sharding.merger.aggregation.AggregationUnitFactory;
import com.dangdang.ddframe.rdb.sharding.merger.common.ResultSetQueryIndex;
import com.dangdang.ddframe.rdb.sharding.merger.common.ResultSetUtil;
import com.dangdang.ddframe.rdb.sharding.parser.result.merger.AggregationColumn;
import com.dangdang.ddframe.rdb.sharding.parser.result.merger.GroupByColumn;
import com.dangdang.ddframe.rdb.sharding.parser.result.merger.MergeContext;
import com.dangdang.ddframe.rdb.sharding.parser.result.merger.OrderByColumn;
import com.google.common.base.Optional;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;

import lombok.AccessLevel;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;

/**
 * Result set that merges GROUP BY rows coming from multiple sharded result sets.
 * 
 * <p>
 * Works in a map-reduce style: the map-reduce is triggered on the first call to nextForSharding().
 * Rows are mapped under their group-key, aggregated per group, and then sorted by the order by columns (shuffle).
 * </p>
 * 
 * @author gaohongtao, zhangliang
 */
@Slf4j
public final class GroupByResultSet extends AbstractShardingResultSet {

    private final List<GroupByColumn> groupByColumns;

    private final List<OrderByColumn> orderByColumns;

    private final List<AggregationColumn> aggregationColumns;

    private final ResultSetMetaData resultSetMetaData;

    private final List<String> columnLabels;

    private Iterator<GroupByValue> groupByResultIterator;

    @Getter(AccessLevel.PROTECTED)
    private GroupByValue currentGroupByResultSet;

    public GroupByResultSet(final List<ResultSet> resultSets, final MergeContext mergeContext) throws SQLException {
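        // Pull grouping, ordering and aggregation metadata from the merge context; column labels come from the first shard's result set.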
        super(resultSets, mergeContext.getLimit());
        groupByColumns = mergeContext.getGroupByColumns();
        orderByColumns = mergeContext.getOrderByColumns();
        aggregationColumns = mergeContext.getAggregationColumns();
        resultSetMetaData = getResultSets().iterator().next().getMetaData();
        columnLabels = new ArrayList<>(resultSetMetaData.getColumnCount());
        fillRelatedColumnNames();
    }

    private void fillRelatedColumnNames() throws SQLException {
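        // Cache every column label exposed by the metadata (JDBC column indexes start at 1).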
        for (int i = 1; i <= resultSetMetaData.getColumnCount(); i++) {
            columnLabels.add(resultSetMetaData.getColumnLabel(i));
        }
    }

    @Override
    protected boolean nextForSharding() throws SQLException {
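        // The map-reduce runs lazily on the first call; afterwards simply advance the iterator over the reduced groups.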
        if (null == groupByResultIterator) {
            ResultSetUtil.fillIndexesForDerivedAggregationColumns(getResultSets().iterator().next(),
                    aggregationColumns);
            groupByResultIterator = reduce(map()).iterator();
        }
        if (groupByResultIterator.hasNext()) {
            currentGroupByResultSet = groupByResultIterator.next();
            return true;
        } else {
            return false;
        }
    }

    private Multimap<GroupByKey, GroupByValue> map() throws SQLException {
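        // Map phase: the ExecutorEngine reads every shard's ResultSet into a map of group-key to row value,
        // then the per-shard maps are merged into a single Multimap.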
        ExecuteUnit<ResultSet, Map<GroupByKey, GroupByValue>> executeUnit = new ExecuteUnit<ResultSet, Map<GroupByKey, GroupByValue>>() {

            @Override
            public Map<GroupByKey, GroupByValue> execute(final ResultSet resultSet) throws SQLException {
                // TODO: use the limit to cap the result size and discard rows beyond that size.
                Map<GroupByKey, GroupByValue> result = new HashMap<>();
                while (resultSet.next()) {
                    GroupByValue groupByValue = new GroupByValue();
                    for (int count = 1; count <= columnLabels.size(); count++) {
                        groupByValue.put(count, resultSetMetaData.getColumnLabel(count),
                                (Comparable<?>) resultSet.getObject(count));
                    }
                    GroupByKey groupByKey = new GroupByKey();
                    for (GroupByColumn each : groupByColumns) {
                        groupByKey.append(ResultSetUtil.getValue(each, resultSet).toString());
                    }
                    result.put(groupByKey, groupByValue);
                }
                log.trace("Result set mapping: {}", result);
                return result;
            }
        };
        MergeUnit<Map<GroupByKey, GroupByValue>, Multimap<GroupByKey, GroupByValue>> mergeUnit = new MergeUnit<Map<GroupByKey, GroupByValue>, Multimap<GroupByKey, GroupByValue>>() {

            @Override
            public Multimap<GroupByKey, GroupByValue> merge(final List<Map<GroupByKey, GroupByValue>> values) {
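                // Rows sharing a group key across shards accumulate under that key in the multimap.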
                Multimap<GroupByKey, GroupByValue> result = HashMultimap.create();
                for (Map<GroupByKey, GroupByValue> each : values) {
                    for (Entry<GroupByKey, GroupByValue> entry : each.entrySet()) {
                        result.put(entry.getKey(), entry.getValue());
                    }
                }
                return result;
            }
        };
        Multimap<GroupByKey, GroupByValue> result = ExecutorEngine.execute(getResultSets(), executeUnit, mergeUnit);
        log.trace("Mapped result: {}", result);
        return result;
    }

    private Collection<GroupByValue> reduce(final Multimap<GroupByKey, GroupByValue> mappedResult)
            throws SQLException {
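        // Reduce phase: aggregate each group's rows column by column, attach the sort columns, and sort the reduced groups.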
        List<GroupByValue> result = new ArrayList<>(mappedResult.values().size() * columnLabels.size());
        for (GroupByKey key : mappedResult.keySet()) {
            Collection<GroupByValue> each = mappedResult.get(key);
            GroupByValue reduceResult = new GroupByValue();
            for (int i = 0; i < columnLabels.size(); i++) {
                int index = i + 1;
                Optional<AggregationColumn> aggregationColumn = findAggregationColumn(index);
                Comparable<?> value = null;
                if (aggregationColumn.isPresent()) {
                    value = aggregate(aggregationColumn.get(), index, each);
                }
                value = null == value ? each.iterator().next().getValue(new ResultSetQueryIndex(index)) : value;
                reduceResult.put(index, columnLabels.get(i), value);
            }
            if (orderByColumns.isEmpty()) {
                reduceResult.addGroupByColumns(groupByColumns);
            } else {
                reduceResult.addOrderColumns(orderByColumns);
            }
            result.add(reduceResult);
        }
        Collections.sort(result);
        log.trace("Reduced result: {}", result);
        return result;
    }

    private Optional<AggregationColumn> findAggregationColumn(final int index) {
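        // Find the aggregation column mapped to this result set index, if there is one.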
        for (AggregationColumn each : aggregationColumns) {
            if (each.getIndex() == index) {
                return Optional.of(each);
            }
        }
        return Optional.absent();
    }

    private Comparable<?> aggregate(final AggregationColumn column, final int index,
            final Collection<GroupByValue> groupByValues) throws SQLException {
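        // Merge one aggregation column's values across all rows of the group via its AggregationUnit.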
        AggregationUnit unit = AggregationUnitFactory.create(column.getAggregationType(), BigDecimal.class);
        for (GroupByValue each : groupByValues) {
            unit.merge(column, each, new ResultSetQueryIndex(index));
        }
        return unit.getResult();
    }
}