org.apache.cassandra.hadoop.cql3.CqlRecordReader.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.cassandra.hadoop.cql3.CqlRecordReader.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.cassandra.hadoop.cql3;

import java.io.IOException;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.net.InetAddress;
import java.nio.ByteBuffer;
import java.util.*;

import com.google.common.base.Function;
import com.google.common.base.Joiner;
import com.google.common.base.Splitter;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.ColumnDefinitions;
import com.datastax.driver.core.ColumnMetadata;
import com.datastax.driver.core.LocalDate;
import com.datastax.driver.core.Metadata;
import com.datastax.driver.core.ResultSet;
import com.datastax.driver.core.Row;
import com.datastax.driver.core.Session;
import com.datastax.driver.core.TableMetadata;
import com.datastax.driver.core.Token;
import com.datastax.driver.core.TupleValue;
import com.datastax.driver.core.UDTValue;
import com.google.common.reflect.TypeToken;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.dht.IPartitioner;
import org.apache.cassandra.hadoop.ColumnFamilySplit;
import org.apache.cassandra.hadoop.ConfigHelper;
import org.apache.cassandra.hadoop.HadoopCompat;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.cassandra.utils.Pair;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;

/**
 * <p>
 * CqlRecordReader reads the rows returned by the CQL query.
 * It uses CQL auto-paging.
 * </p>
 * <p>
 * The key is a Long: a local, per-reader CQL row counter that is 1 for the first row returned.
 * </p>
 * {@code
 * The value is the row as a C* java driver CQL result set row.
 * 1) select clause must include partition key columns (to calculate the progress based on the actual CF rows processed)
 * 2) where clause must include token(partition_key1, ...  , partition_keyn) > ? and
 *       token(partition_key1, ... , partition_keyn) <= ?  (in the right order)
 * }
 */
public class CqlRecordReader extends RecordReader<Long, Row>
        implements org.apache.hadoop.mapred.RecordReader<Long, Row>, AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(CqlRecordReader.class);

    private ColumnFamilySplit split;
    private RowIterator rowIterator;

    private Pair<Long, Row> currentRow;
    private int totalRowCount; // total number of rows to fetch
    private String keyspace;
    private String cfName;
    private String cqlQuery;
    private Cluster cluster;
    private Session session;
    private IPartitioner partitioner;
    private String inputColumns;
    private String userDefinedWhereClauses;

    private List<String> partitionKeys = new ArrayList<>();

    // partition keys -- key aliases
    private LinkedHashMap<String, Boolean> partitionBoundColumns = Maps.newLinkedHashMap();
    protected int nativeProtocolVersion = 1;

    public CqlRecordReader() {
        super();
    }

    /**
     * Prepares the reader for the given split: reads the job configuration, connects to the
     * cluster, resolves the CQL query to run and creates the row iterator (which executes it).
     * <p>
     * If a cluster has already been created by a previous call, the method returns right after
     * refreshing the configuration-derived fields, without reconnecting or re-running the query.
     *
     * @param split   the input split to read; must be a {@link ColumnFamilySplit}
     * @param context the task context the configuration is taken from
     * @throws IOException      declared by the Hadoop API
     * @throws RuntimeException if the cluster cannot be created or no session can be opened
     * @throws AssertionError   if a custom CQL query is configured together with input columns
     *                          and / or where clauses
     */
    @Override
    public void initialize(InputSplit split, TaskAttemptContext context) throws IOException {
        this.split = (ColumnFamilySplit) split;
        Configuration conf = HadoopCompat.getConfiguration(context);

        // A length of Long.MAX_VALUE falls back to the configured split size. Any other length is
        // clamped to Integer.MAX_VALUE so that the narrowing cast cannot wrap around and corrupt
        // the progress estimate.
        long splitLength = this.split.getLength();
        totalRowCount = (splitLength < Long.MAX_VALUE) ? (int) Math.min(splitLength, Integer.MAX_VALUE)
                : ConfigHelper.getInputSplitSize(conf);
        cfName = ConfigHelper.getInputColumnFamily(conf);
        keyspace = ConfigHelper.getInputKeyspace(conf);
        partitioner = ConfigHelper.getInputPartitioner(conf);
        inputColumns = CqlConfigHelper.getInputcolumns(conf);
        userDefinedWhereClauses = CqlConfigHelper.getInputWhereClauses(conf);

        try {
            // already connected: nothing more to do
            if (cluster != null)
                return;

            // create a Cluster instance
            String[] locations = split.getLocations();
            cluster = CqlConfigHelper.getInputCluster(locations, conf);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        if (cluster != null)
            session = cluster.connect(quote(keyspace));

        if (session == null)
            throw new RuntimeException("Can't create connection session");

        //get negotiated serialization protocol
        nativeProtocolVersion = cluster.getConfiguration().getProtocolOptions().getProtocolVersion().toInt();

        // If the user provides a CQL query then we will use it without validation
        // otherwise we will fall back to building a query using the:
        //   inputColumns
        //   whereClauses
        cqlQuery = CqlConfigHelper.getInputCql(conf);
        // validate that the user hasn't tried to give us a custom query along with input columns
        // and where clauses
        if (StringUtils.isNotEmpty(cqlQuery)
                && (StringUtils.isNotEmpty(inputColumns) || StringUtils.isNotEmpty(userDefinedWhereClauses))) {
            throw new AssertionError("Cannot define a custom query with input columns and / or where clauses");
        }

        if (StringUtils.isEmpty(cqlQuery))
            cqlQuery = buildQuery();
        logger.trace("cqlQuery {}", cqlQuery);

        rowIterator = new RowIterator();
        logger.trace("created {}", rowIterator);
    }

    /**
     * Closes the session and the cluster, if they were created.
     * <p>
     * The cluster is closed in a {@code finally} block so that a failure while closing the
     * session does not leave the cluster open.
     */
    public void close() {
        try {
            if (session != null)
                session.close();
        } finally {
            if (cluster != null)
                cluster.close();
        }
    }

    /**
     * Returns the key half of the current row pair, as stored by {@code nextKeyValue}.
     *
     * @return the key of the current row
     */
    public Long getCurrentKey() {
        return this.currentRow.left;
    }

    /**
     * Returns the row half of the current row pair, as stored by {@code nextKeyValue}.
     *
     * @return the current driver row
     */
    public Row getCurrentValue() {
        return this.currentRow.right;
    }

    /**
     * Estimates how far through the split the reader is.
     *
     * @return 1.0 once the iterator is exhausted, otherwise the ratio of rows counted so far
     *         to the expected total, capped at 1.0
     */
    public float getProgress() {
        if (rowIterator.hasNext()) {
            // the estimate may be slightly off the real figure, but it is close enough
            float fraction = (float) rowIterator.totalRead / totalRowCount;
            if (fraction > 1.0F)
                return 1.0F;
            return fraction;
        }
        return 1.0F;
    }

    /**
     * Advances to the next row of the split, making it available through
     * {@code getCurrentKey} and {@code getCurrentValue}.
     *
     * @return {@code true} if a row was read, {@code false} once the iterator is exhausted
     * @throws IOException if fetching the next row fails; the original exception is attached
     *                     as the cause
     */
    public boolean nextKeyValue() throws IOException {
        if (!rowIterator.hasNext()) {
            logger.trace("Finished scanning {} rows (estimate was: {})", rowIterator.totalRead, totalRowCount);
            return false;
        }

        try {
            currentRow = rowIterator.next();
        } catch (Exception e) {
            // Rethrow as IOException so the client can catch and handle it on its side.
            // The original exception is kept as the cause so its type and stack trace survive.
            throw new IOException(e.getMessage(), e);
        }
        return true;
    }

    // Old (mapred) Hadoop API entry point. That API hands us the key and value objects to
    // fill in, whereas the new API asks us for them, so the result of the new-API call is
    // copied into the supplied value; this costs a little performance.
    // This would not work for wide rows, but ColumnFamilyInputFormat and
    // ColumnFamilyRecordReader don't support them either, so it is fine for now.
    public boolean next(Long key, Row value) throws IOException {
        if (!nextKeyValue())
            return false;

        WrappedRow holder = (WrappedRow) value;
        holder.setRow(getCurrentValue());
        return true;
    }

    /**
     * Reports the current position as the iterator's {@code totalRead} counter.
     *
     * @return the value of the row iterator's {@code totalRead} counter
     */
    public long getPos() throws IOException {
        final long position = rowIterator.totalRead;
        return position;
    }

    /**
     * Creates the key object for the old Hadoop API.
     *
     * @return a {@code Long} holding zero
     */
    public Long createKey() {
        final Long zeroKey = 0L;
        return zeroKey;
    }

    /**
     * Creates the value object for the old Hadoop API: an empty {@code WrappedRow} that
     * {@code next} later points at the current row.
     *
     * @return a new, empty row wrapper
     */
    public Row createValue() {
        final WrappedRow emptyHolder = new WrappedRow();
        return emptyHolder;
    }

    /**
     * Returns the native protocol version of the cluster connection.
     *
     * @return serialization protocol version.
     */
    public int getNativeProtocolVersion() {
        return this.nativeProtocolVersion;
    }

    /**
     * CQL row iterator: executes the CQL query for the token range of the current split and
     * hands out each result row paired with a running key.
     * <p>
     * Input cql query
     *  1) select clause must include key columns (if we use partition key based row count)
     *  2) where clause must include token(partition_key1 ... partition_keyn) > ? and
     *     token(partition_key1 ... partition_keyn) <= ?
     */
    private class RowIterator extends AbstractIterator<Pair<Long, Row>> {
        // running counter used as the key of each returned row
        private long keyId = 0L;
        // incremented each time the partition key columns differ from the previously counted key
        protected int totalRead = 0;
        protected Iterator<Row> rows;
        // partition key column values of the most recently counted key
        private Map<String, ByteBuffer> previousRowKey = new HashMap<String, ByteBuffer>();

        public RowIterator() {
            // bind the split's start and end tokens to the two '?' markers of the query
            AbstractType tokenType = partitioner.getTokenValidator();
            Object startToken = tokenType.compose(tokenType.fromString(split.getStartToken()));
            Object endToken = tokenType.compose(tokenType.fromString(split.getEndToken()));
            ResultSet resultSet = session.execute(cqlQuery, startToken, endToken);

            // remember the partition key column names; they are used to detect key changes
            for (ColumnMetadata keyColumn : cluster.getMetadata().getKeyspace(quote(keyspace))
                    .getTable(quote(cfName)).getPartitionKey()) {
                partitionBoundColumns.put(keyColumn.getName(), Boolean.TRUE);
            }
            rows = resultSet.iterator();
        }

        protected Pair<Long, Row> computeNext() {
            if (rows == null || !rows.hasNext())
                return endOfData();

            Row current = rows.next();
            Map<String, ByteBuffer> currentKey = new HashMap<String, ByteBuffer>(partitionBoundColumns.size());
            for (String column : partitionBoundColumns.keySet())
                currentKey.put(column, current.getBytesUnsafe(column));

            // increase total CF row read when this is the first key seen or the key has changed
            boolean firstKey = previousRowKey.isEmpty() && !currentKey.isEmpty();
            if (firstKey || differsFromPreviousKey(currentKey)) {
                previousRowKey = currentKey;
                totalRead++;
            }

            keyId++;
            return Pair.create(keyId, current);
        }

        /** Tells whether any partition key column of {@code candidate} differs from the previous key. */
        private boolean differsFromPreviousKey(Map<String, ByteBuffer> candidate) {
            for (String column : partitionBoundColumns.keySet()) {
                // this is not correct - but we don't seem to have easy access to better type information here
                if (ByteBufferUtil.compareUnsigned(candidate.get(column), previousRowKey.get(column)) != 0)
                    return true;
            }
            return false;
        }
    }

    /**
     * A {@link Row} whose backing row can be swapped with {@link #setRow(Row)}.
     * Every accessor forwards to the backing row; no accessor may be used before one is set.
     */
    private static class WrappedRow implements Row {
        // the row all calls are forwarded to
        private Row delegate;

        public void setRow(Row row) {
            this.delegate = row;
        }

        @Override
        public ColumnDefinitions getColumnDefinitions() {
            return delegate.getColumnDefinitions();
        }

        @Override
        public boolean isNull(int i) {
            return delegate.isNull(i);
        }

        @Override
        public boolean isNull(String name) {
            return delegate.isNull(name);
        }

        @Override
        public Object getObject(int i) {
            return delegate.getObject(i);
        }

        @Override
        public Object getObject(String s) {
            return delegate.getObject(s);
        }

        @Override
        public boolean getBool(int i) {
            return delegate.getBool(i);
        }

        @Override
        public boolean getBool(String name) {
            return delegate.getBool(name);
        }

        @Override
        public short getShort(int i) {
            return delegate.getShort(i);
        }

        @Override
        public short getShort(String s) {
            return delegate.getShort(s);
        }

        @Override
        public byte getByte(int i) {
            return delegate.getByte(i);
        }

        @Override
        public byte getByte(String s) {
            return delegate.getByte(s);
        }

        @Override
        public int getInt(int i) {
            return delegate.getInt(i);
        }

        @Override
        public int getInt(String name) {
            return delegate.getInt(name);
        }

        @Override
        public long getLong(int i) {
            return delegate.getLong(i);
        }

        @Override
        public long getLong(String name) {
            return delegate.getLong(name);
        }

        @Override
        public Date getTimestamp(int i) {
            return delegate.getTimestamp(i);
        }

        @Override
        public Date getTimestamp(String s) {
            return delegate.getTimestamp(s);
        }

        @Override
        public LocalDate getDate(int i) {
            return delegate.getDate(i);
        }

        @Override
        public LocalDate getDate(String s) {
            return delegate.getDate(s);
        }

        @Override
        public long getTime(int i) {
            return delegate.getTime(i);
        }

        @Override
        public long getTime(String s) {
            return delegate.getTime(s);
        }

        @Override
        public float getFloat(int i) {
            return delegate.getFloat(i);
        }

        @Override
        public float getFloat(String name) {
            return delegate.getFloat(name);
        }

        @Override
        public double getDouble(int i) {
            return delegate.getDouble(i);
        }

        @Override
        public double getDouble(String name) {
            return delegate.getDouble(name);
        }

        @Override
        public ByteBuffer getBytesUnsafe(int i) {
            return delegate.getBytesUnsafe(i);
        }

        @Override
        public ByteBuffer getBytesUnsafe(String name) {
            return delegate.getBytesUnsafe(name);
        }

        @Override
        public ByteBuffer getBytes(int i) {
            return delegate.getBytes(i);
        }

        @Override
        public ByteBuffer getBytes(String name) {
            return delegate.getBytes(name);
        }

        @Override
        public String getString(int i) {
            return delegate.getString(i);
        }

        @Override
        public String getString(String name) {
            return delegate.getString(name);
        }

        @Override
        public BigInteger getVarint(int i) {
            return delegate.getVarint(i);
        }

        @Override
        public BigInteger getVarint(String name) {
            return delegate.getVarint(name);
        }

        @Override
        public BigDecimal getDecimal(int i) {
            return delegate.getDecimal(i);
        }

        @Override
        public BigDecimal getDecimal(String name) {
            return delegate.getDecimal(name);
        }

        @Override
        public UUID getUUID(int i) {
            return delegate.getUUID(i);
        }

        @Override
        public UUID getUUID(String name) {
            return delegate.getUUID(name);
        }

        @Override
        public InetAddress getInet(int i) {
            return delegate.getInet(i);
        }

        @Override
        public InetAddress getInet(String name) {
            return delegate.getInet(name);
        }

        @Override
        public <T> List<T> getList(int i, Class<T> elementsClass) {
            return delegate.getList(i, elementsClass);
        }

        @Override
        public <T> List<T> getList(int i, TypeToken<T> typeToken) {
            return delegate.getList(i, typeToken);
        }

        @Override
        public <T> List<T> getList(String name, Class<T> elementsClass) {
            return delegate.getList(name, elementsClass);
        }

        @Override
        public <T> List<T> getList(String s, TypeToken<T> typeToken) {
            return delegate.getList(s, typeToken);
        }

        @Override
        public <T> Set<T> getSet(int i, Class<T> elementsClass) {
            return delegate.getSet(i, elementsClass);
        }

        @Override
        public <T> Set<T> getSet(int i, TypeToken<T> typeToken) {
            return delegate.getSet(i, typeToken);
        }

        @Override
        public <T> Set<T> getSet(String name, Class<T> elementsClass) {
            return delegate.getSet(name, elementsClass);
        }

        @Override
        public <T> Set<T> getSet(String s, TypeToken<T> typeToken) {
            return delegate.getSet(s, typeToken);
        }

        @Override
        public <K, V> Map<K, V> getMap(int i, Class<K> keysClass, Class<V> valuesClass) {
            return delegate.getMap(i, keysClass, valuesClass);
        }

        @Override
        public <K, V> Map<K, V> getMap(int i, TypeToken<K> typeToken, TypeToken<V> typeToken1) {
            return delegate.getMap(i, typeToken, typeToken1);
        }

        @Override
        public <K, V> Map<K, V> getMap(String name, Class<K> keysClass, Class<V> valuesClass) {
            return delegate.getMap(name, keysClass, valuesClass);
        }

        @Override
        public <K, V> Map<K, V> getMap(String s, TypeToken<K> typeToken, TypeToken<V> typeToken1) {
            return delegate.getMap(s, typeToken, typeToken1);
        }

        @Override
        public UDTValue getUDTValue(int i) {
            return delegate.getUDTValue(i);
        }

        @Override
        public UDTValue getUDTValue(String name) {
            return delegate.getUDTValue(name);
        }

        @Override
        public TupleValue getTupleValue(int i) {
            return delegate.getTupleValue(i);
        }

        @Override
        public TupleValue getTupleValue(String name) {
            return delegate.getTupleValue(name);
        }

        @Override
        public Token getToken(int i) {
            return delegate.getToken(i);
        }

        @Override
        public Token getToken(String name) {
            return delegate.getToken(name);
        }

        @Override
        public Token getPartitionKeyToken() {
            return delegate.getPartitionKeyToken();
        }
    }

    /**
     * Build a query for the reader of the form:
     *
     * SELECT * FROM ks.cf WHERE token(pk1,...pkn)>? AND token(pk1,...pkn)<=? [AND user where clauses ALLOW FILTERING]
     *
     * The select list is "*" unless input columns were configured, in which case it is the
     * partition key columns followed by the configured columns.
     *
     * @return the CQL query text, with two bind markers for the start and end token
     */
    private String buildQuery() {
        fetchKeys();

        List<String> columns = getSelectColumns();
        String selectColumnList = columns.isEmpty() ? "*" : makeColumnList(columns);
        String partitionKeyList = makeColumnList(partitionKeys);

        String baseQuery = String.format("SELECT %s FROM %s.%s WHERE token(%s)>? AND token(%s)<=?",
                selectColumnList, quote(keyspace), quote(cfName), partitionKeyList, partitionKeyList);

        // The user-supplied where clauses are appended AFTER formatting: they are free text and
        // must not be interpreted as part of the format pattern (a '%' in them would otherwise
        // throw or be substituted).
        return baseQuery + getAdditionalWhereClauses();
    }

    /**
     * Builds the tail of the query from the user-defined where clauses.
     *
     * @return an empty string when no where clauses are configured, otherwise
     *         " AND " + the clauses + " ALLOW FILTERING"
     */
    private String getAdditionalWhereClauses() {
        if (StringUtils.isEmpty(userDefinedWhereClauses))
            return "";

        return " AND " + userDefinedWhereClauses + " ALLOW FILTERING";
    }

    /**
     * Works out which columns to put in the select clause.
     *
     * @return an empty list when no input columns are configured; otherwise all partition keys
     *         followed by each configured column that is not itself a partition key
     */
    private List<String> getSelectColumns() {
        List<String> selected = new ArrayList<>();
        if (StringUtils.isEmpty(inputColumns))
            return selected;

        // the partition keys must always be selected, in addition to what the user asked for
        selected.addAll(partitionKeys);
        for (String requested : Splitter.on(',').split(inputColumns)) {
            if (partitionKeys.contains(requested))
                continue;
            selected.add(requested);
        }
        return selected;
    }

    /**
     * Joins the given column names into a comma-separated list, quoting each name.
     *
     * @param columns the column names to join
     * @return the quoted names separated by ','
     */
    private String makeColumnList(Collection<String> columns) {
        Function<String, String> quoter = new Function<String, String>() {
            public String apply(String column) {
                return quote(column);
            }
        };
        Iterable<String> quotedColumns = Iterables.transform(columns, quoter);
        return Joiner.on(',').join(quotedColumns);
    }

    /**
     * Looks up the table metadata through the session's cluster and appends the names of its
     * partition key columns to {@code partitionKeys}.
     *
     * @throws RuntimeException if no metadata is found for the table
     */
    private void fetchKeys() {
        // get CF meta data
        Metadata clusterMetadata = session.getCluster().getMetadata();
        TableMetadata tableMetadata = clusterMetadata.getKeyspace(Metadata.quote(keyspace))
                .getTable(Metadata.quote(cfName));
        if (tableMetadata == null)
            throw new RuntimeException("No table metadata found for " + keyspace + "." + cfName);

        // Assumption: tableMetadata.getPartitionKey() always returns the columns
        // in order of component_index
        for (ColumnMetadata keyColumn : tableMetadata.getPartitionKey())
            partitionKeys.add(keyColumn.getName());
    }

    /**
     * Wraps an identifier in double quotes, doubling any double quote it contains.
     *
     * @param identifier the raw identifier
     * @return the quoted identifier
     */
    private String quote(String identifier) {
        String escaped = identifier.replace("\"", "\"\"");
        StringBuilder quoted = new StringBuilder(escaped.length() + 2);
        quoted.append('"').append(escaped).append('"');
        return quoted.toString();
    }
}