org.apache.solr.search.grouping.distributed.shardresultserializer.TopGroupsResultTransformer.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.solr.search.grouping.distributed.shardresultserializer.TopGroupsResultTransformer.java

Source

package org.apache.solr.search.grouping.distributed.shardresultserializer;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import org.apache.lucene.document.Document;
import org.apache.lucene.document.DocumentStoredFieldVisitor;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.UnicodeUtil;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.handler.component.ShardDoc;
import org.apache.solr.schema.FieldType;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.grouping.Command;
import org.apache.solr.search.grouping.distributed.command.QueryCommand;
import org.apache.solr.search.grouping.distributed.command.QueryCommandResult;
import org.apache.solr.search.grouping.distributed.command.TopGroupsFieldCommand;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Implementation for transforming {@link TopGroups} and {@link TopDocs} into a {@link NamedList} structure and
 * vice versa.
 * <p>
 * The structure produced by {@link #transform(List)} holds one entry per command, keyed by the command key:
 * <ul>
 * <li>for a {@link TopGroupsFieldCommand}: {@code totalGroupedHitCount}, {@code totalHitCount}, an optional
 * {@code totalGroupCount}, followed by one {@link NamedList} per group keyed by the readable group value;</li>
 * <li>for a {@link QueryCommand}: {@code matches}, {@code totalHits}, an optional {@code maxScore} and a
 * {@code documents} list.</li>
 * </ul>
 * Every serialized document carries an {@code id} (external form of the unique key field), an optional
 * {@code score} and, when the hit is a {@link FieldDoc}, its converted {@code sortValues}.
 */
public class TopGroupsResultTransformer implements ShardResultTransformer<List<Command>, Map<String, ?>> {

    private final ResponseBuilder rb;

    private static final Logger log = LoggerFactory.getLogger(TopGroupsResultTransformer.class);

    /**
     * @param rb the response builder that supplies the searcher, schema and grouping spec used while
     *           (de)serializing
     */
    public TopGroupsResultTransformer(ResponseBuilder rb) {
        this.rb = rb;
    }

    /**
     * {@inheritDoc}
     * <p>
     * Commands that are neither a {@link TopGroupsFieldCommand} nor a {@link QueryCommand} are stored with a
     * {@code null} value under their key.
     */
    @Override
    public NamedList transform(List<Command> data) throws IOException {
        NamedList<NamedList> result = new NamedList<NamedList>();
        final IndexSchema schema = rb.req.getSearcher().getSchema();
        for (Command command : data) {
            NamedList commandResult;
            if (command instanceof TopGroupsFieldCommand) {
                TopGroupsFieldCommand fieldCommand = (TopGroupsFieldCommand) command;
                SchemaField groupField = schema.getField(fieldCommand.getKey());
                commandResult = serializeTopGroups(fieldCommand.result(), groupField);
            } else if (command instanceof QueryCommand) {
                QueryCommand queryCommand = (QueryCommand) command;
                commandResult = serializeTopDocs(queryCommand.result());
            } else {
                commandResult = null;
            }

            result.add(command.getKey(), commandResult);
        }
        return result;
    }

    /**
     * {@inheritDoc}
     * <p>
     * A command result that contains a top-level {@code totalHits} entry is rebuilt as a
     * {@link QueryCommandResult}. Any other result is rebuilt as {@link TopGroups} of {@link BytesRef} group values.
     * A {@code null} command result (as emitted by {@link #transform(List)} for unsupported commands) is skipped.
     */
    @Override
    public Map<String, ?> transformToNative(NamedList<NamedList> shardResponse, Sort groupSort,
            Sort sortWithinGroup, String shard) {
        Map<String, Object> result = new HashMap<String, Object>();

        for (Map.Entry<String, NamedList> entry : shardResponse) {
            String key = entry.getKey();
            NamedList commandResult = entry.getValue();
            if (commandResult == null) {
                // Nothing was serialized for this command, so there is nothing to rebuild.
                continue;
            }

            Integer totalHits = (Integer) commandResult.get("totalHits");
            if (totalHits != null) {
                result.put(key, toQueryCommandResult(commandResult, totalHits, shard));
            } else {
                result.put(key, toTopGroups(commandResult, groupSort, sortWithinGroup, shard));
            }
        }

        return result;
    }

    /**
     * Rebuilds a {@link QueryCommandResult} from the structure written by {@link #serializeTopDocs}.
     * A missing {@code maxScore} becomes {@link Float#NaN}.
     */
    private QueryCommandResult toQueryCommandResult(NamedList commandResult, Integer totalHits, String shard) {
        Integer matches = (Integer) commandResult.get("matches");
        Float maxScore = (Float) commandResult.get("maxScore");
        if (maxScore == null) {
            maxScore = Float.NaN;
        }

        @SuppressWarnings("unchecked")
        List<NamedList<Object>> documents = (List<NamedList<Object>>) commandResult.get("documents");
        ScoreDoc[] scoreDocs = toShardDocs(documents, shard);
        return new QueryCommandResult(new TopDocs(totalHits, scoreDocs, maxScore), matches);
    }

    /**
     * Rebuilds {@link TopGroups} from the structure written by {@link #serializeTopGroups}.
     * Group scores and the overall max score are not transmitted, so they are set to {@link Float#NaN}.
     */
    private TopGroups<BytesRef> toTopGroups(NamedList commandResult, Sort groupSort, Sort sortWithinGroup,
            String shard) {
        Integer totalGroupedHitCount = (Integer) commandResult.get("totalGroupedHitCount");
        Integer totalHitCount = (Integer) commandResult.get("totalHitCount");

        List<GroupDocs<BytesRef>> groupDocs = new ArrayList<GroupDocs<BytesRef>>();
        for (int i = 0; i < commandResult.size(); i++) {
            Object value = commandResult.getVal(i);
            // The header entries (totalGroupedHitCount, totalHitCount and the optional totalGroupCount) are plain
            // numbers; only group entries are NamedLists. Skipping by type instead of a fixed offset keeps this in
            // sync with serializeTopGroups, which only writes totalGroupCount when it is known.
            if (!(value instanceof NamedList)) {
                continue;
            }
            String groupValue = commandResult.getName(i);
            @SuppressWarnings("unchecked")
            NamedList<Object> groupResult = (NamedList<Object>) value;
            Integer totalGroupHits = (Integer) groupResult.get("totalHits");
            Float maxScore = (Float) groupResult.get("maxScore");
            if (maxScore == null) {
                maxScore = Float.NaN;
            }

            @SuppressWarnings("unchecked")
            List<NamedList<Object>> documents = (List<NamedList<Object>>) groupResult.get("documents");
            ScoreDoc[] scoreDocs = toShardDocs(documents, shard);

            BytesRef groupValueRef = groupValue != null ? new BytesRef(groupValue) : null;
            groupDocs.add(new GroupDocs<BytesRef>(Float.NaN, maxScore, totalGroupHits, scoreDocs, groupValueRef,
                    null));
        }

        @SuppressWarnings("unchecked")
        GroupDocs<BytesRef>[] groupDocsArr = groupDocs.toArray(new GroupDocs[groupDocs.size()]);
        return new TopGroups<BytesRef>(groupSort.getSort(), sortWithinGroup.getSort(),
                totalHitCount, totalGroupedHitCount, groupDocsArr, Float.NaN);
    }

    /**
     * Converts serialized documents back into {@link ShardDoc}s attributed to {@code shard}.
     * <p>
     * A missing {@code id} or {@code sortValues} entry is logged as a warning and stored as {@code null};
     * a missing {@code score} becomes {@link Float#NaN}.
     */
    private ScoreDoc[] toShardDocs(List<NamedList<Object>> documents, String shard) {
        ScoreDoc[] scoreDocs = new ScoreDoc[documents.size()];
        int j = 0;
        for (NamedList<Object> document : documents) {
            Object docId = document.get("id");
            Object uniqueId = null;
            if (docId != null) {
                uniqueId = docId.toString();
            } else {
                log.warn("doc {} has null 'id'", document);
            }
            Float score = (Float) document.get("score");
            if (score == null) {
                score = Float.NaN;
            }
            Object[] sortValues = null;
            Object sortValuesVal = document.get("sortValues");
            if (sortValuesVal != null) {
                sortValues = ((List) sortValuesVal).toArray();
            } else {
                log.warn("doc {} has null 'sortValues'", document);
            }
            scoreDocs[j++] = new ShardDoc(score, sortValues, uniqueId, shard);
        }
        return scoreDocs;
    }

    /**
     * Serializes the top groups of a field command.
     *
     * @param data       the grouped hits to serialize
     * @param groupField the schema field used to convert each indexed group value to its readable form
     * @return the header counts followed by one entry per group, keyed by the readable group value
     * @throws IOException if a document's unique key cannot be loaded
     */
    protected NamedList serializeTopGroups(TopGroups<BytesRef> data, SchemaField groupField) throws IOException {
        NamedList<Object> result = new NamedList<Object>();
        result.add("totalGroupedHitCount", data.totalGroupedHitCount);
        result.add("totalHitCount", data.totalHitCount);
        if (data.totalGroupCount != null) {
            result.add("totalGroupCount", data.totalGroupCount);
        }
        CharsRef spare = new CharsRef();

        final IndexSchema schema = rb.req.getSearcher().getSchema();
        SchemaField uniqueField = schema.getUniqueKeyField();
        // Sort values of documents inside a group are interpreted against the within-group sort.
        final Sort sortWithinGroup = rb.getGroupingSpec().getSortWithinGroup();
        for (GroupDocs<BytesRef> searchGroup : data.groups) {
            NamedList<Object> groupResult = new NamedList<Object>();
            groupResult.add("totalHits", searchGroup.totalHits);
            if (!Float.isNaN(searchGroup.maxScore)) {
                groupResult.add("maxScore", searchGroup.maxScore);
            }

            List<NamedList<Object>> documents = new ArrayList<NamedList<Object>>();
            for (ScoreDoc scoreDoc : searchGroup.scoreDocs) {
                NamedList<Object> document = new NamedList<Object>();
                documents.add(document);

                Document doc = retrieveDocument(uniqueField, scoreDoc.doc);
                document.add("id", uniqueField.getType().toExternal(doc.getField(uniqueField.getName())));
                if (!Float.isNaN(scoreDoc.score)) {
                    document.add("score", scoreDoc.score);
                }
                if (scoreDoc instanceof FieldDoc) {
                    document.add("sortValues",
                            convertSortValues((FieldDoc) scoreDoc, sortWithinGroup, schema, spare));
                }
            }
            groupResult.add("documents", documents);
            String groupValue = searchGroup.groupValue != null
                    ? groupField.getType().indexedToReadable(searchGroup.groupValue.utf8ToString())
                    : null;
            result.add(groupValue, groupResult);
        }

        return result;
    }

    /**
     * Serializes the result of a query command.
     * <p>
     * {@code maxScore} and per-document {@code score} are only written when the grouping spec says scores are
     * needed.
     *
     * @param result the query command result to serialize
     * @return {@code matches}, {@code totalHits}, optional {@code maxScore} and the {@code documents} list
     * @throws IOException if a document's unique key cannot be loaded
     */
    protected NamedList serializeTopDocs(QueryCommandResult result) throws IOException {
        NamedList<Object> queryResult = new NamedList<Object>();
        queryResult.add("matches", result.getMatches());
        queryResult.add("totalHits", result.getTopDocs().totalHits);
        final boolean needScore = rb.getGroupingSpec().isNeedScore();
        if (needScore) {
            queryResult.add("maxScore", result.getTopDocs().getMaxScore());
        }
        List<NamedList> documents = new ArrayList<NamedList>();
        queryResult.add("documents", documents);

        final IndexSchema schema = rb.req.getSearcher().getSchema();
        SchemaField uniqueField = schema.getUniqueKeyField();
        CharsRef spare = new CharsRef();
        // NOTE(review): query-command sort values are interpreted against the group sort (not the within-group
        // sort); confirm this matches the sort the QueryCommand was executed with.
        final Sort groupSort = rb.getGroupingSpec().getGroupSort();
        for (ScoreDoc scoreDoc : result.getTopDocs().scoreDocs) {
            NamedList<Object> document = new NamedList<Object>();
            documents.add(document);

            Document doc = retrieveDocument(uniqueField, scoreDoc.doc);
            document.add("id", uniqueField.getType().toExternal(doc.getField(uniqueField.getName())));
            if (needScore) {
                document.add("score", scoreDoc.score);
            }
            if (scoreDoc instanceof FieldDoc) {
                document.add("sortValues", convertSortValues((FieldDoc) scoreDoc, groupSort, schema, spare));
            }
        }

        return queryResult;
    }

    /**
     * Converts the raw sort values of {@code fieldDoc} into their readable object form.
     * <p>
     * The j-th value is interpreted with the schema field named by the j-th sort field of {@code sort}.
     * {@link BytesRef} and {@link String} values are converted from their indexed form. Values whose sort field has
     * no name or no matching schema field, and values of any other type, are passed through unchanged.
     *
     * @param spare scratch buffer reused for UTF-8 to UTF-16 decoding
     */
    private Object[] convertSortValues(FieldDoc fieldDoc, Sort sort, IndexSchema schema, CharsRef spare) {
        Object[] convertedSortValues = new Object[fieldDoc.fields.length];
        for (int j = 0; j < fieldDoc.fields.length; j++) {
            Object sortValue = fieldDoc.fields[j];
            String sortFieldName = sort.getSort()[j].getField();
            SchemaField field = sortFieldName != null ? schema.getFieldOrNull(sortFieldName) : null;
            if (field != null) {
                FieldType fieldType = field.getType();
                String indexedValue = null;
                if (sortValue instanceof BytesRef) {
                    UnicodeUtil.UTF8toUTF16((BytesRef) sortValue, spare);
                    indexedValue = spare.toString();
                } else if (sortValue instanceof String) {
                    indexedValue = (String) sortValue;
                }
                if (indexedValue != null) {
                    sortValue = fieldType
                            .toObject(field.createField(fieldType.indexedToReadable(indexedValue), 1.0f));
                }
            }
            convertedSortValues[j] = sortValue;
        }
        return convertedSortValues;
    }

    /**
     * Loads only the unique key field of document {@code doc} from the request's searcher.
     */
    private Document retrieveDocument(final SchemaField uniqueField, int doc) throws IOException {
        DocumentStoredFieldVisitor visitor = new DocumentStoredFieldVisitor(uniqueField.getName());
        rb.req.getSearcher().doc(doc, visitor);
        return visitor.getDocument();
    }

}