Java tutorial: reading ORC files with Flink's OrcRowInputFormat

The class below, org.apache.flink.orc.OrcRowInputFormat, reads ORC files into Flink Row records and supports projection push-down (selectFields) and filter push-down (addPredicate).
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.orc;

import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.io.FileInputFormat;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.core.fs.FileInputSplit;
import org.apache.flink.core.fs.Path;
import org.apache.flink.types.Row;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
import org.apache.hadoop.hive.ql.io.sarg.PredicateLeaf;
import org.apache.hadoop.hive.ql.io.sarg.SearchArgument;
import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.orc.OrcConf;
import org.apache.orc.OrcFile;
import org.apache.orc.Reader;
import org.apache.orc.RecordReader;
import org.apache.orc.StripeInformation;
import org.apache.orc.TypeDescription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.math.BigDecimal;
import java.sql.Date;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.apache.flink.orc.OrcBatchReader.fillRows;

/**
 * InputFormat to read ORC files.
 */
public class OrcRowInputFormat extends FileInputFormat<Row> implements ResultTypeQueryable<Row> {

    private static final Logger LOG = LoggerFactory.getLogger(OrcRowInputFormat.class);

    // the default number of rows read in a batch
    private static final int DEFAULT_BATCH_SIZE = 1000;

    // the number of rows to read in a batch
    private int batchSize;
    // the configuration to read with
    private Configuration conf;
    // the schema of the ORC files to read
    private TypeDescription schema;

    // the fields of the ORC schema that the returned Rows are composed of.
    private int[] selectedFields;
    // the type information of the Rows returned by this InputFormat.
    private transient RowTypeInfo rowType;

    // the ORC reader
    private transient RecordReader orcRowsReader;
    // the vectorized row data to be read in a batch
    private transient VectorizedRowBatch rowBatch;
    // the vector of rows that is read in a batch
    private transient Row[] rows;
    // the number of rows in the current batch
    private transient int rowsInBatch;
    // the index of the next row to return
    private transient int nextRow;

    private ArrayList<Predicate> conjunctPredicates = new ArrayList<>();
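
    // A note on the design (summarizing the code below): reading works in two stages.
    // ORC hands back vectorized batches of 'batchSize' rows (VectorizedRowBatch), which
    // are materialized into the pre-allocated 'rows' buffer; nextRecord() then serves
    // rows from that buffer until it is exhausted and the next batch is fetched.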

    /**
     * Creates an OrcRowInputFormat.
     *
     * @param path The path to read ORC files from.
     * @param schemaString The schema of the ORC files as String.
     * @param orcConfig The configuration to read the ORC files with.
     */
    public OrcRowInputFormat(String path, String schemaString, Configuration orcConfig) {
        this(path, TypeDescription.fromString(schemaString), orcConfig, DEFAULT_BATCH_SIZE);
    }

    /**
     * Creates an OrcRowInputFormat.
     *
     * @param path The path to read ORC files from.
     * @param schemaString The schema of the ORC files as String.
     * @param orcConfig The configuration to read the ORC files with.
     * @param batchSize The number of Row objects to read in a batch.
     */
    public OrcRowInputFormat(String path, String schemaString, Configuration orcConfig, int batchSize) {
        this(path, TypeDescription.fromString(schemaString), orcConfig, batchSize);
    }

    /**
     * Creates an OrcRowInputFormat.
     *
     * @param path The path to read ORC files from.
     * @param orcSchema The schema of the ORC files as ORC TypeDescription.
     * @param orcConfig The configuration to read the ORC files with.
     * @param batchSize The number of Row objects to read in a batch.
     */
    public OrcRowInputFormat(String path, TypeDescription orcSchema, Configuration orcConfig, int batchSize) {
        super(new Path(path));

        // configure OrcRowInputFormat
        this.schema = orcSchema;
        this.rowType = (RowTypeInfo) OrcBatchReader.schemaToTypeInfo(schema);
        this.conf = orcConfig;
        this.batchSize = batchSize;

        // set default selection mask, i.e., all fields.
        this.selectedFields = new int[this.schema.getChildren().size()];
        for (int i = 0; i < selectedFields.length; i++) {
            this.selectedFields[i] = i;
        }
    }

    /**
     * Adds a filter predicate to reduce the number of rows to be returned by the input format.
     * Multiple conjunctive predicates can be added by calling this method multiple times.
     *
     * <p>Note: Predicates can significantly reduce the amount of data that is read.
     * However, the OrcRowInputFormat does not guarantee that all returned rows satisfy the
     * predicates. Moreover, predicates are only applied if the referenced field is among the
     * selected fields.
     *
     * @param predicate The filter predicate.
     */
    public void addPredicate(Predicate predicate) {
        // validate
        validatePredicate(predicate);
        // add predicate
        this.conjunctPredicates.add(predicate);
    }

    private void validatePredicate(Predicate pred) {
        if (pred instanceof ColumnPredicate) {
            // check column name
            String colName = ((ColumnPredicate) pred).columnName;
            if (!this.schema.getFieldNames().contains(colName)) {
                throw new IllegalArgumentException(
                    "Predicate cannot be applied. " +
                        "Column '" + colName + "' does not exist in ORC schema.");
            }
        } else if (pred instanceof Not) {
            validatePredicate(((Not) pred).child());
        } else if (pred instanceof Or) {
            for (Predicate p : ((Or) pred).children()) {
                validatePredicate(p);
            }
        }
    }

    /**
     * Selects the fields of the ORC schema that are returned by the InputFormat.
     *
     * @param selectedFields The indices of the fields of the ORC schema that are returned by the InputFormat.
     */
    public void selectFields(int... selectedFields) {
        // set field mapping
        this.selectedFields = selectedFields;
        // adapt result type
        this.rowType = RowTypeInfo.projectFields(this.rowType, selectedFields);
    }
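
    // selectFields sketch (the schema "struct<id:bigint,name:string,price:double>" is a
    // hypothetical example): selectFields(2, 0) makes the format return Rows of
    // (price, id). Since predicates are only applied to selected fields, a predicate on
    // 'name' would pass validation here but never prune any data.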

    /**
     * Computes the ORC projection mask of the fields to include from the selected fields.
     *
     * @return The ORC projection mask.
     */
    private boolean[] computeProjectionMask() {
        // mask with all fields of the schema
        boolean[] projectionMask = new boolean[schema.getMaximumId() + 1];
        // for each selected field
        for (int inIdx : selectedFields) {
            // set all nested fields of a selected field to true
            TypeDescription fieldSchema = schema.getChildren().get(inIdx);
            for (int i = fieldSchema.getId(); i <= fieldSchema.getMaximumId(); i++) {
                projectionMask[i] = true;
            }
        }
        return projectionMask;
    }

    @Override
    public void openInputFormat() throws IOException {
        super.openInputFormat();
        // create and initialize the row batch
        this.rows = new Row[batchSize];
        for (int i = 0; i < batchSize; i++) {
            rows[i] = new Row(selectedFields.length);
        }
    }

    @Override
    public void open(FileInputSplit fileSplit) throws IOException {
        LOG.debug("Opening ORC file {}", fileSplit.getPath());

        // open ORC file and create reader
        org.apache.hadoop.fs.Path hPath = new org.apache.hadoop.fs.Path(fileSplit.getPath().getPath());
        Reader orcReader = OrcFile.createReader(hPath, OrcFile.readerOptions(conf));

        // get offset and length for the stripes that start in the split
        Tuple2<Long, Long> offsetAndLength = getOffsetAndLengthForSplit(fileSplit, getStripes(orcReader));

        // create ORC row reader configuration
        Reader.Options options = getOptions(orcReader)
            .schema(schema)
            .range(offsetAndLength.f0, offsetAndLength.f1)
            .useZeroCopy(OrcConf.USE_ZEROCOPY.getBoolean(conf))
            .skipCorruptRecords(OrcConf.SKIP_CORRUPT_DATA.getBoolean(conf))
            .tolerateMissingSchema(OrcConf.TOLERATE_MISSING_SCHEMA.getBoolean(conf));

        // configure filters
        if (!conjunctPredicates.isEmpty()) {
            SearchArgument.Builder b = SearchArgumentFactory.newBuilder();
            b = b.startAnd();
            for (Predicate predicate : conjunctPredicates) {
                predicate.add(b);
            }
            b = b.end();
            options.searchArgument(b.build(), new String[]{});
        }

        // configure selected fields
        options.include(computeProjectionMask());

        // create ORC row reader
        this.orcRowsReader = orcReader.rows(options);

        // assign ids
        this.schema.getId();

        // create row batch
        this.rowBatch = schema.createRowBatch(batchSize);
        rowsInBatch = 0;
        nextRow = 0;
    }

    @VisibleForTesting
    Reader.Options getOptions(Reader orcReader) {
        return orcReader.options();
    }

    @VisibleForTesting
    List<StripeInformation> getStripes(Reader orcReader) {
        return orcReader.getStripes();
    }

    private Tuple2<Long, Long> getOffsetAndLengthForSplit(FileInputSplit split, List<StripeInformation> stripes) {
        long splitStart = split.getStart();
        long splitEnd = splitStart + split.getLength();

        long readStart = Long.MAX_VALUE;
        long readEnd = Long.MIN_VALUE;

        for (StripeInformation s : stripes) {
            if (splitStart <= s.getOffset() && s.getOffset() < splitEnd) {
                // stripe starts in split, so it is included
                readStart = Math.min(readStart, s.getOffset());
                readEnd = Math.max(readEnd, s.getOffset() + s.getLength());
            }
        }

        if (readStart < Long.MAX_VALUE) {
            // at least one stripe is included
            return Tuple2.of(readStart, readEnd - readStart);
        } else {
            return Tuple2.of(0L, 0L);
        }
    }

    @Override
    public void close() throws IOException {
        if (orcRowsReader != null) {
            this.orcRowsReader.close();
        }
        this.orcRowsReader = null;
    }

    @Override
    public void closeInputFormat() throws IOException {
        this.rows = null;
        this.schema = null;
        this.rowBatch = null;
    }

    @Override
    public boolean reachedEnd() throws IOException {
        return !ensureBatch();
    }
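
    // Worked example for getOffsetAndLengthForSplit (hypothetical numbers): for a split
    // covering bytes [0, 1000) and stripes starting at offsets 3 (length 400), 500
    // (length 600), and 1200 (length 300), the first two stripes start inside the split,
    // so the method returns offset 3 and length (500 + 600) - 3 = 1097; the stripe at
    // 1200 is left to the split that contains its start.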

    /**
     * Checks if there is at least one row left in the batch to return.
     * If no more rows are available, it reads another batch of rows.
     *
     * @return Returns true if there is one more row to return, false otherwise.
     * @throws IOException thrown if an exception happens while reading a batch.
     */
    private boolean ensureBatch() throws IOException {
        if (nextRow >= rowsInBatch) {
            // No more rows available in the Rows array.
            nextRow = 0;
            // Try to read the next batch of rows from the ORC file.
            boolean moreRows = orcRowsReader.nextBatch(rowBatch);

            if (moreRows) {
                // Load the data into the Rows array.
                rowsInBatch = fillRows(rows, schema, rowBatch, selectedFields);
            }
            return moreRows;
        }
        // there is at least one Row left in the Rows array.
        return true;
    }

    @Override
    public Row nextRecord(Row reuse) throws IOException {
        // return the next row; the reuse parameter is ignored because rows are served
        // from the pre-filled batch array.
        return rows[this.nextRow++];
    }

    @Override
    public TypeInformation<Row> getProducedType() {
        return rowType;
    }

    // --------------------------------------------------------------------------------------------
    //  Custom serialization methods
    // --------------------------------------------------------------------------------------------

    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(batchSize);
        this.conf.write(out);
        out.writeUTF(schema.toString());

        out.writeInt(selectedFields.length);
        for (int f : selectedFields) {
            out.writeInt(f);
        }

        out.writeInt(conjunctPredicates.size());
        for (Predicate p : conjunctPredicates) {
            out.writeObject(p);
        }
    }

    @SuppressWarnings("unchecked")
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        batchSize = in.readInt();
        org.apache.hadoop.conf.Configuration configuration = new org.apache.hadoop.conf.Configuration();
        configuration.readFields(in);

        if (this.conf == null) {
            this.conf = configuration;
        }
        this.schema = TypeDescription.fromString(in.readUTF());

        this.selectedFields = new int[in.readInt()];
        for (int i = 0; i < selectedFields.length; i++) {
            this.selectedFields[i] = in.readInt();
        }

        this.conjunctPredicates = new ArrayList<>();
        int numPreds = in.readInt();
        for (int i = 0; i < numPreds; i++) {
            conjunctPredicates.add((Predicate) in.readObject());
        }
    }

    @Override
    public boolean supportsMultiPaths() {
        return true;
    }

    // --------------------------------------------------------------------------------------------
    //  Getter methods for tests
    // --------------------------------------------------------------------------------------------

    @VisibleForTesting
    Configuration getConfiguration() {
        return conf;
    }

    @VisibleForTesting
    int getBatchSize() {
        return batchSize;
    }

    @VisibleForTesting
    String getSchema() {
        return schema.toString();
    }
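
    // Lifecycle sketch: Flink's runtime normally drives these calls; they are shown here
    // only to illustrate the reachedEnd()/nextRecord() contract (the 'split' variable is
    // assumed to be a FileInputSplit provided by the runtime):
    //
    //   format.openInputFormat();
    //   format.open(split);
    //   while (!format.reachedEnd()) {
    //       Row row = format.nextRecord(null);
    //       // process row ...
    //   }
    //   format.close();
    //   format.closeInputFormat();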

    // --------------------------------------------------------------------------------------------
    //  Classes to define predicates
    // --------------------------------------------------------------------------------------------

    /**
     * A filter predicate that can be evaluated by the OrcRowInputFormat.
     */
    public abstract static class Predicate implements Serializable {
        protected abstract SearchArgument.Builder add(SearchArgument.Builder builder);
    }

    abstract static class ColumnPredicate extends Predicate {
        final String columnName;
        final PredicateLeaf.Type literalType;

        ColumnPredicate(String columnName, PredicateLeaf.Type literalType) {
            this.columnName = columnName;
            this.literalType = literalType;
        }

        Object castLiteral(Serializable literal) {
            switch (literalType) {
                case LONG:
                    if (literal instanceof Byte) {
                        return Long.valueOf((Byte) literal);
                    } else if (literal instanceof Short) {
                        return Long.valueOf((Short) literal);
                    } else if (literal instanceof Integer) {
                        return Long.valueOf((Integer) literal);
                    } else if (literal instanceof Long) {
                        return literal;
                    } else {
                        throw new IllegalArgumentException("A predicate on a LONG column requires an integer " +
                            "literal, i.e., Byte, Short, Integer, or Long.");
                    }
                case FLOAT:
                    if (literal instanceof Float) {
                        return Double.valueOf((Float) literal);
                    } else if (literal instanceof Double) {
                        return literal;
                    } else if (literal instanceof BigDecimal) {
                        return ((BigDecimal) literal).doubleValue();
                    } else {
                        throw new IllegalArgumentException("A predicate on a FLOAT column requires a floating " +
                            "literal, i.e., Float or Double.");
                    }
                case STRING:
                    if (literal instanceof String) {
                        return literal;
                    } else {
                        throw new IllegalArgumentException(
                            "A predicate on a STRING column requires a String literal.");
                    }
                case BOOLEAN:
                    if (literal instanceof Boolean) {
                        return literal;
                    } else {
                        throw new IllegalArgumentException(
                            "A predicate on a BOOLEAN column requires a Boolean literal.");
                    }
                case DATE:
                    if (literal instanceof Date) {
                        return literal;
                    } else {
                        throw new IllegalArgumentException(
                            "A predicate on a DATE column requires a java.sql.Date literal.");
                    }
                case TIMESTAMP:
                    if (literal instanceof Timestamp) {
                        return literal;
                    } else {
                        throw new IllegalArgumentException(
                            "A predicate on a TIMESTAMP column requires a java.sql.Timestamp literal.");
                    }
                case DECIMAL:
                    if (literal instanceof BigDecimal) {
                        return new HiveDecimalWritable(HiveDecimal.create((BigDecimal) literal));
                    } else {
                        throw new IllegalArgumentException(
                            "A predicate on a DECIMAL column requires a BigDecimal literal.");
                    }
                default:
                    throw new IllegalArgumentException("Unknown literal type " + literalType);
            }
        }
    }

    abstract static class BinaryPredicate extends ColumnPredicate {
        final Serializable literal;

        BinaryPredicate(String columnName, PredicateLeaf.Type literalType, Serializable literal) {
            super(columnName, literalType);
            this.literal = literal;
        }
    }

    /**
     * An EQUALS predicate that can be evaluated by the OrcRowInputFormat.
     */
    public static class Equals extends BinaryPredicate {
        /**
         * Creates an EQUALS predicate.
         *
         * @param columnName The column to check.
         * @param literalType The type of the literal.
         * @param literal The literal value to check the column against.
         */
        public Equals(String columnName, PredicateLeaf.Type literalType, Serializable literal) {
            super(columnName, literalType, literal);
        }

        @Override
        protected SearchArgument.Builder add(SearchArgument.Builder builder) {
            return builder.equals(columnName, literalType, castLiteral(literal));
        }

        @Override
        public String toString() {
            return columnName + " = " + literal;
        }
    }
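
    // castLiteral sketch: new Equals("age", PredicateLeaf.Type.LONG, (byte) 32) widens the
    // Byte literal to the Long 32L before handing it to the ORC SearchArgument builder,
    // because ORC compares LONG columns against Long values ('age' is a hypothetical column).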

    /**
     * An EQUALS predicate that can be evaluated with null safety by the OrcRowInputFormat.
     */
    public static class NullSafeEquals extends BinaryPredicate {
        /**
         * Creates a null-safe EQUALS predicate.
         *
         * @param columnName The column to check.
         * @param literalType The type of the literal.
         * @param literal The literal value to check the column against.
         */
        public NullSafeEquals(String columnName, PredicateLeaf.Type literalType, Serializable literal) {
            super(columnName, literalType, literal);
        }

        @Override
        protected SearchArgument.Builder add(SearchArgument.Builder builder) {
            return builder.nullSafeEquals(columnName, literalType, castLiteral(literal));
        }

        @Override
        public String toString() {
            return columnName + " = " + literal;
        }
    }

    /**
     * A LESS_THAN predicate that can be evaluated by the OrcRowInputFormat.
     */
    public static class LessThan extends BinaryPredicate {
        /**
         * Creates a LESS_THAN predicate.
         *
         * @param columnName The column to check.
         * @param literalType The type of the literal.
         * @param literal The literal value to check the column against.
         */
        public LessThan(String columnName, PredicateLeaf.Type literalType, Serializable literal) {
            super(columnName, literalType, literal);
        }

        @Override
        protected SearchArgument.Builder add(SearchArgument.Builder builder) {
            return builder.lessThan(columnName, literalType, castLiteral(literal));
        }

        @Override
        public String toString() {
            return columnName + " < " + literal;
        }
    }

    /**
     * A LESS_THAN_EQUALS predicate that can be evaluated by the OrcRowInputFormat.
     */
    public static class LessThanEquals extends BinaryPredicate {
        /**
         * Creates a LESS_THAN_EQUALS predicate.
         *
         * @param columnName The column to check.
         * @param literalType The type of the literal.
         * @param literal The literal value to check the column against.
         */
        public LessThanEquals(String columnName, PredicateLeaf.Type literalType, Serializable literal) {
            super(columnName, literalType, literal);
        }

        @Override
        protected SearchArgument.Builder add(SearchArgument.Builder builder) {
            return builder.lessThanEquals(columnName, literalType, castLiteral(literal));
        }

        @Override
        public String toString() {
            return columnName + " <= " + literal;
        }
    }

    /**
     * An IS_NULL predicate that can be evaluated by the OrcRowInputFormat.
     */
    public static class IsNull extends ColumnPredicate {
        /**
         * Creates an IS_NULL predicate.
         *
         * @param columnName The column to check for null.
         * @param literalType The type of the column to check for null.
         */
        public IsNull(String columnName, PredicateLeaf.Type literalType) {
            super(columnName, literalType);
        }

        @Override
        protected SearchArgument.Builder add(SearchArgument.Builder builder) {
            return builder.isNull(columnName, literalType);
        }

        @Override
        public String toString() {
            return columnName + " IS NULL";
        }
    }

    /**
     * A BETWEEN predicate that can be evaluated by the OrcRowInputFormat.
     */
    public static class Between extends ColumnPredicate {
        private Serializable lowerBound;
        private Serializable upperBound;

        /**
         * Creates a BETWEEN predicate.
         *
         * @param columnName The column to check.
         * @param literalType The type of the literals.
         * @param lowerBound The literal value of the (inclusive) lower bound to check the column against.
         * @param upperBound The literal value of the (inclusive) upper bound to check the column against.
         */
        public Between(String columnName, PredicateLeaf.Type literalType, Serializable lowerBound, Serializable upperBound) {
            super(columnName, literalType);
            this.lowerBound = lowerBound;
            this.upperBound = upperBound;
        }

        @Override
        protected SearchArgument.Builder add(SearchArgument.Builder builder) {
            return builder.between(columnName, literalType, castLiteral(lowerBound), castLiteral(upperBound));
        }

        @Override
        public String toString() {
            return lowerBound + " <= " + columnName + " <= " + upperBound;
        }
    }
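
    // Composition sketch (column names are hypothetical): the conjunctive filter
    // "age BETWEEN 18 AND 65 AND NOT(name = 'unknown')" is built with two addPredicate calls:
    //
    //   format.addPredicate(new Between("age", PredicateLeaf.Type.LONG, 18, 65));
    //   format.addPredicate(new Not(new Equals("name", PredicateLeaf.Type.STRING, "unknown")));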

    /**
     * An IN predicate that can be evaluated by the OrcRowInputFormat.
     */
    public static class In extends ColumnPredicate {
        private Serializable[] literals;

        /**
         * Creates an IN predicate.
         *
         * @param columnName The column to check.
         * @param literalType The type of the literals.
         * @param literals The literal values to check the column against.
         */
        public In(String columnName, PredicateLeaf.Type literalType, Serializable... literals) {
            super(columnName, literalType);
            this.literals = literals;
        }

        @Override
        protected SearchArgument.Builder add(SearchArgument.Builder builder) {
            Object[] castedLiterals = new Object[literals.length];
            for (int i = 0; i < literals.length; i++) {
                castedLiterals[i] = castLiteral(literals[i]);
            }
            return builder.in(columnName, literalType, (Object[]) castedLiterals);
        }

        @Override
        public String toString() {
            return columnName + " IN " + Arrays.toString(literals);
        }
    }

    /**
     * A NOT predicate to negate a predicate that can be evaluated by the OrcRowInputFormat.
     */
    public static class Not extends Predicate {
        private final Predicate pred;

        /**
         * Creates a NOT predicate.
         *
         * @param predicate The predicate to negate.
         */
        public Not(Predicate predicate) {
            this.pred = predicate;
        }

        @Override
        protected SearchArgument.Builder add(SearchArgument.Builder builder) {
            return pred.add(builder.startNot()).end();
        }

        protected Predicate child() {
            return pred;
        }

        @Override
        public String toString() {
            return "NOT(" + pred.toString() + ")";
        }
    }

    /**
     * An OR predicate that can be evaluated by the OrcRowInputFormat.
     */
    public static class Or extends Predicate {
        private final Predicate[] preds;

        /**
         * Creates an OR predicate.
         *
         * @param predicates The disjunctive predicates.
         */
        public Or(Predicate... predicates) {
            this.preds = predicates;
        }

        @Override
        protected SearchArgument.Builder add(SearchArgument.Builder builder) {
            SearchArgument.Builder withOr = builder.startOr();
            for (Predicate p : preds) {
                withOr = p.add(withOr);
            }
            return withOr.end();
        }

        protected Iterable<Predicate> children() {
            return Arrays.asList(preds);
        }

        @Override
        public String toString() {
            return "OR(" + Arrays.toString(preds) + ")";
        }
    }
}
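
To put the pieces together, here is a minimal end-to-end usage sketch as a separate class, using Flink's batch DataSet API. The path and ORC schema string are hypothetical examples; everything else relies only on the constructors and predicate classes defined above.

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.orc.OrcRowInputFormat;
import org.apache.flink.types.Row;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.ql.io.sarg.PredicateLeaf;

public class OrcRowInputFormatExample {

    public static void main(String[] args) throws Exception {
        // read ORC files from a (hypothetical) path with a (hypothetical) schema
        OrcRowInputFormat orcFormat = new OrcRowInputFormat(
            "hdfs:///tmp/orders",
            "struct<id:bigint,name:string,price:double>",
            new Configuration());

        // project to the 'id' and 'price' fields ...
        orcFormat.selectFields(0, 2);
        // ... and push a filter into the ORC reader; 'price' is among the selected
        // fields, so the predicate is actually applied
        orcFormat.addPredicate(
            new OrcRowInputFormat.LessThan("price", PredicateLeaf.Type.FLOAT, 100.0));

        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        DataSet<Row> rows = env.createInput(orcFormat);
        rows.print();
    }
}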