org.apache.parquet.avro.AvroWriteSupport.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.parquet.avro.AvroWriteSupport.java

Source

/* 
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.parquet.avro;

import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericFixed;
import org.apache.avro.generic.IndexedRecord;
import org.apache.avro.util.Utf8;
import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.hadoop.api.WriteSupport;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.io.api.RecordConsumer;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.Type;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.parquet.Preconditions;

/**
 * Avro implementation of {@link WriteSupport} for generic, specific, and
 * reflect models. Use {@link AvroParquetWriter} or
 * {@link AvroParquetOutputFormat} rather than using this class directly.
 */
public class AvroWriteSupport<T> extends WriteSupport<T> {

    /**
     * Configuration key naming the {@link AvroDataSupplier} class that provides
     * the Avro data model when none was passed to the constructor.
     */
    public static final String AVRO_DATA_SUPPLIER = "parquet.avro.write.data.supplier";

    /**
     * Records the given supplier class, by class name, in the configuration
     * under {@link #AVRO_DATA_SUPPLIER}.
     *
     * @param configuration the configuration to update
     * @param suppClass the supplier class whose name is stored
     */
    public static void setAvroDataSupplier(Configuration configuration,
            Class<? extends AvroDataSupplier> suppClass) {
        String supplierClassName = suppClass.getName();
        configuration.set(AVRO_DATA_SUPPLIER, supplierClassName);
    }

    /** Configuration key under which the Avro schema is stored as a string. */
    static final String AVRO_SCHEMA = "parquet.avro.schema";
    /** Avro schema handed to {@code writeValue} for every map key. */
    private static final Schema MAP_KEY_SCHEMA = Schema.create(Schema.Type.STRING);

    /**
     * Configuration key read by {@link #init(Configuration)}: when true, lists
     * are written by {@code TwoLevelListWriter}, otherwise by
     * {@code ThreeLevelListWriter}.
     */
    public static final String WRITE_OLD_LIST_STRUCTURE = "parquet.avro.write-old-list-structure";
    /** Value used when {@link #WRITE_OLD_LIST_STRUCTURE} is not set. */
    static final boolean WRITE_OLD_LIST_STRUCTURE_DEFAULT = true;

    // Field names passed to the record consumer when writing maps and lists.
    private static final String MAP_REPEATED_NAME = "key_value";
    private static final String MAP_KEY_NAME = "key";
    private static final String MAP_VALUE_NAME = "value";
    // Repeated-field name used by ThreeLevelListWriter.
    private static final String LIST_REPEATED_NAME = "list";
    // Repeated-field name used by TwoLevelListWriter.
    private static final String OLD_LIST_REPEATED_NAME = "array";
    static final String LIST_ELEMENT_NAME = "element";

    // Consumer that receives every write call; assigned in prepareForWrite.
    private RecordConsumer recordConsumer;
    // Parquet schema of the top-level record.
    private MessageType rootSchema;
    // Avro schema of the top-level record.
    private Schema rootAvroSchema;
    // Avro data model used to read record fields and resolve unions.
    private GenericData model;
    // List-writing strategy chosen in init.
    private ListWriter listWriter;

    /**
     * Creates an instance with no schema or data model set;
     * {@link #init(Configuration)} resolves both from the configuration.
     */
    public AvroWriteSupport() {
    }

    /**
     * Creates a write support for the given schemas without a data model; the
     * model is resolved from the configuration in {@link #init(Configuration)}.
     *
     * @param schema the Parquet schema of the records to write
     * @param avroSchema the Avro schema of the records to write
     * @deprecated use {@link #AvroWriteSupport(MessageType, Schema, GenericData)}
     */
    @Deprecated
    public AvroWriteSupport(MessageType schema, Schema avroSchema) {
        this(schema, avroSchema, null);
    }

    /**
     * Creates a write support with explicit schemas and data model.
     *
     * @param schema the Parquet schema of the records to write
     * @param avroSchema the Avro schema of the records to write
     * @param model the Avro data model; if null, {@link #init(Configuration)}
     *        resolves one from the configuration
     */
    public AvroWriteSupport(MessageType schema, Schema avroSchema, GenericData model) {
        this.rootSchema = schema;
        this.rootAvroSchema = avroSchema;
        this.model = model;
    }

    /**
     * Stores the string form of the given Avro schema in the configuration
     * under the {@code parquet.avro.schema} key.
     *
     * @param configuration the configuration to update
     * @param schema the Avro schema to store
     * @see org.apache.parquet.avro.AvroParquetOutputFormat#setSchema(org.apache.hadoop.mapreduce.Job, org.apache.avro.Schema)
     */
    public static void setSchema(Configuration configuration, Schema schema) {
        String schemaJson = schema.toString();
        configuration.set(AVRO_SCHEMA, schemaJson);
    }

    /**
     * Resolves the schemas, the data model and the list-writing strategy, then
     * builds the write context.
     *
     * <p>If no Avro schema was passed to the constructor, it is parsed from the
     * {@code parquet.avro.schema} configuration value and converted to a
     * Parquet schema. If no data model was passed, one is obtained from the
     * configured data supplier. The Avro schema string is added to the extra
     * metadata of the returned context.
     *
     * @param configuration the configuration to read settings from
     * @return a write context holding the Parquet schema and the extra metadata
     * @throws NullPointerException if no schema was passed to the constructor
     *         and the configuration has no {@code parquet.avro.schema} value
     */
    @Override
    public WriteContext init(Configuration configuration) {
        if (rootAvroSchema == null) {
            String schemaJson = configuration.get(AVRO_SCHEMA);
            if (schemaJson == null) {
                // A missing setting would otherwise be handed to the parser as
                // null and fail without naming what is absent. The exception
                // type is kept as NullPointerException for existing callers.
                throw new NullPointerException("No Avro schema available: set \"" + AVRO_SCHEMA
                        + "\" in the configuration or pass a schema to the constructor");
            }
            this.rootAvroSchema = new Schema.Parser().parse(schemaJson);
            this.rootSchema = new AvroSchemaConverter().convert(rootAvroSchema);
        }

        if (model == null) {
            this.model = getDataModel(configuration);
        }

        boolean writeOldListStructure = configuration.getBoolean(WRITE_OLD_LIST_STRUCTURE,
                WRITE_OLD_LIST_STRUCTURE_DEFAULT);
        if (writeOldListStructure) {
            this.listWriter = new TwoLevelListWriter();
        } else {
            this.listWriter = new ThreeLevelListWriter();
        }

        Map<String, String> extraMetaData = new HashMap<String, String>();
        extraMetaData.put(AvroReadSupport.AVRO_SCHEMA_METADATA_KEY, rootAvroSchema.toString());
        return new WriteContext(rootSchema, extraMetaData);
    }

    /**
     * Stores the consumer that all subsequent write calls emit to.
     *
     * @param recordConsumer the consumer to write records to
     */
    @Override
    public void prepareForWrite(RecordConsumer recordConsumer) {
        this.recordConsumer = recordConsumer;
    }

    /**
     * Writes one record as a complete message, using the root Parquet and Avro
     * schemas. Overloaded version kept for backward compatibility.
     *
     * @param record the record whose fields are written
     */
    @SuppressWarnings("unchecked")
    public void write(IndexedRecord record) {
        recordConsumer.startMessage();
        writeRecordFields(rootSchema, rootAvroSchema, record);
        recordConsumer.endMessage();
    }

    /**
     * Writes one record as a complete message, using the root Parquet and Avro
     * schemas.
     *
     * @param record the record whose fields are written
     */
    @Override
    public void write(T record) {
        recordConsumer.startMessage();
        writeRecordFields(rootSchema, rootAvroSchema, record);
        recordConsumer.endMessage();
    }

    /**
     * Writes a nested record: its fields wrapped in a start/end group pair.
     *
     * @param schema the Parquet group type of the record
     * @param avroSchema the Avro schema of the record
     * @param record the record object read through the data model
     */
    private void writeRecord(GroupType schema, Schema avroSchema, Object record) {
        recordConsumer.startGroup();
        writeRecordFields(schema, avroSchema, record);
        recordConsumer.endGroup();
    }

    /**
     * Writes every field of {@code record} whose Avro type is not NULL. Fields
     * with a null value are skipped unless the Parquet field is required.
     *
     * @param schema the Parquet group type listing the fields to write
     * @param avroSchema the Avro record schema
     * @param record the record object read through the data model
     * @throws RuntimeException if a required Parquet field has a null value
     */
    private void writeRecordFields(GroupType schema, Schema avroSchema, Object record) {
        List<Type> parquetFields = schema.getFields();
        // Avro fields of type NULL have no Parquet counterpart, so the Parquet
        // position is tracked separately from the Avro position.
        int parquetPos = 0;
        int avroPos = -1;
        for (Schema.Field avroField : avroSchema.getFields()) {
            avroPos++;
            if (avroField.schema().getType() == Schema.Type.NULL) {
                continue;
            }
            Type parquetField = parquetFields.get(parquetPos);
            Object fieldValue = model.getField(record, avroField.name(), avroPos);
            if (fieldValue == null) {
                if (parquetField.isRepetition(Type.Repetition.REQUIRED)) {
                    throw new RuntimeException("Null-value for required field: " + avroField.name());
                }
            } else {
                recordConsumer.startField(parquetField.getName(), parquetPos);
                writeValue(parquetField, avroField.schema(), fieldValue);
                recordConsumer.endField(parquetField.getName(), parquetPos);
            }
            parquetPos++;
        }
    }

    /**
     * Writes a map as a wrapper group holding one repeated {@code key_value}
     * group per entry. An empty map produces an empty wrapper group.
     *
     * @param schema the Parquet group type of the map
     * @param avroSchema the Avro map schema
     * @param map the entries to write
     * @throws RuntimeException if a value is null and the value type is not optional
     */
    private <V> void writeMap(GroupType schema, Schema avroSchema, Map<CharSequence, V> map) {
        GroupType keyValueGroup = schema.getType(0).asGroupType();
        Type keyType = keyValueGroup.getType(0);
        Type valueType = keyValueGroup.getType(1);

        recordConsumer.startGroup(); // group wrapper (original type MAP)
        if (map.size() > 0) {
            recordConsumer.startField(MAP_REPEATED_NAME, 0);
            for (Map.Entry<CharSequence, V> entry : map.entrySet()) {
                writeMapEntry(keyType, valueType, avroSchema, entry);
            }
            recordConsumer.endField(MAP_REPEATED_NAME, 0);
        }
        recordConsumer.endGroup();
    }

    /** Writes one map entry as a key_value group: the key, then the value when present. */
    private <V> void writeMapEntry(Type keyType, Type valueType, Schema avroSchema,
            Map.Entry<CharSequence, V> entry) {
        recordConsumer.startGroup(); // repeated group key_value, middle layer

        recordConsumer.startField(MAP_KEY_NAME, 0);
        writeValue(keyType, MAP_KEY_SCHEMA, entry.getKey());
        recordConsumer.endField(MAP_KEY_NAME, 0);

        V entryValue = entry.getValue();
        if (entryValue == null) {
            if (!valueType.isRepetition(Type.Repetition.OPTIONAL)) {
                throw new RuntimeException("Null map value for " + avroSchema.getName());
            }
        } else {
            recordConsumer.startField(MAP_VALUE_NAME, 1);
            writeValue(valueType, avroSchema.getValueType(), entryValue);
            recordConsumer.endField(MAP_VALUE_NAME, 1);
        }

        recordConsumer.endGroup();
    }

    /**
     * Writes a union value as a group with a single populated field named
     * {@code member<N>}, where N is the position of the resolved branch among
     * the non-null branches of the union.
     *
     * @param parquetSchema the Parquet group type of the union
     * @param avroSchema the Avro union schema
     * @param value the value to write
     */
    private void writeUnion(GroupType parquetSchema, Schema avroSchema, Object value) {
        recordConsumer.startGroup();

        // The data model picks the union branch that matches the value.
        int avroIndex = model.resolveUnion(avroSchema, value);
        GroupType unionGroup = parquetSchema.asGroupType();

        // Null branches are absent from the Parquet group, so the Parquet
        // position is the number of non-null branches before the resolved one.
        int parquetIndex = 0;
        for (int i = 0; i < avroIndex; i++) {
            if (avroSchema.getTypes().get(i).getType() != Schema.Type.NULL) {
                parquetIndex++;
            }
        }

        // Each branch has its own column set; only the resolved one is written.
        String memberName = "member" + parquetIndex;
        recordConsumer.startField(memberName, parquetIndex);
        Type memberType = unionGroup.getType(parquetIndex);
        Schema memberSchema = avroSchema.getTypes().get(avroIndex);
        writeValue(memberType, memberSchema, value);
        recordConsumer.endField(memberName, parquetIndex);

        recordConsumer.endGroup();
    }

    /**
     * Writes a single value to the record consumer, dispatching on the Avro
     * type obtained from {@code AvroSchemaConverter.getNonNull(avroSchema)}.
     *
     * @param type the Parquet type of the value
     * @param avroSchema the Avro schema of the value
     * @param value the value to write
     */
    @SuppressWarnings("unchecked")
    private void writeValue(Type type, Schema avroSchema, Object value) {
        Schema nonNullAvroSchema = AvroSchemaConverter.getNonNull(avroSchema);
        switch (nonNullAvroSchema.getType()) {
        case BOOLEAN:
            recordConsumer.addBoolean((Boolean) value);
            break;
        case INT:
            // Characters are not Numbers, so they are converted separately.
            if (value instanceof Character) {
                recordConsumer.addInteger(((Character) value).charValue());
            } else {
                recordConsumer.addInteger(((Number) value).intValue());
            }
            break;
        case LONG:
            recordConsumer.addLong(((Number) value).longValue());
            break;
        case FLOAT:
            recordConsumer.addFloat(((Number) value).floatValue());
            break;
        case DOUBLE:
            recordConsumer.addDouble(((Number) value).doubleValue());
            break;
        case BYTES:
            if (value instanceof byte[]) {
                recordConsumer.addBinary(Binary.fromReusedByteArray((byte[]) value));
            } else {
                recordConsumer.addBinary(Binary.fromReusedByteBuffer((ByteBuffer) value));
            }
            break;
        case STRING:
            recordConsumer.addBinary(fromAvroString(value));
            break;
        case RECORD:
            writeRecord(type.asGroupType(), nonNullAvroSchema, value);
            break;
        case ENUM:
            recordConsumer.addBinary(Binary.fromString(value.toString()));
            break;
        case ARRAY:
            listWriter.writeList(type.asGroupType(), nonNullAvroSchema, value);
            break;
        case MAP:
            writeMap(type.asGroupType(), nonNullAvroSchema, (Map<CharSequence, ?>) value);
            break;
        case UNION:
            writeUnion(type.asGroupType(), nonNullAvroSchema, value);
            break;
        case FIXED:
            recordConsumer.addBinary(Binary.fromReusedByteArray(((GenericFixed) value).bytes()));
            break;
        default:
            // Any other type (Schema.Type.NULL) writes nothing.
            break;
        }
    }

    /**
     * Converts an Avro string value to a {@link Binary}. A {@link Utf8} is
     * wrapped over its existing byte array; anything else goes through
     * {@code toString()}.
     *
     * @param value the string-like value to convert
     * @return the binary form of the value
     */
    private Binary fromAvroString(Object value) {
        if (!(value instanceof Utf8)) {
            return Binary.fromString(value.toString());
        }
        Utf8 text = (Utf8) value;
        // Only the first getByteLength() bytes of the backing array are used.
        return Binary.fromReusedByteArray(text.getBytes(), 0, text.getByteLength());
    }

    /**
     * Instantiates the data supplier class named under
     * {@link #AVRO_DATA_SUPPLIER}, defaulting to {@code SpecificDataSupplier},
     * and returns the data model it supplies.
     *
     * @param conf the configuration to read the supplier class from
     * @return the data model provided by the supplier
     */
    private static GenericData getDataModel(Configuration conf) {
        Class<? extends AvroDataSupplier> supplierClass = conf.getClass(AVRO_DATA_SUPPLIER,
                SpecificDataSupplier.class, AvroDataSupplier.class);
        AvroDataSupplier supplier = ReflectionUtils.newInstance(supplierClass, conf);
        return supplier.get();
    }

    private abstract class ListWriter {

        protected abstract void writeCollection(GroupType type, Schema schema, Collection<?> collection);

        protected abstract void writeObjectArray(GroupType type, Schema schema, Object[] array);

        protected abstract void startArray();

        protected abstract void endArray();

        public void writeList(GroupType schema, Schema avroSchema, Object value) {
            recordConsumer.startGroup(); // group wrapper (original type LIST)
            if (value instanceof Collection) {
                writeCollection(schema, avroSchema, (Collection) value);
            } else {
                Class<?> arrayClass = value.getClass();
                Preconditions.checkArgument(arrayClass.isArray(),
                        "Cannot write unless collection or array: " + arrayClass.getName());
                writeJavaArray(schema, avroSchema, arrayClass, value);
            }
            recordConsumer.endGroup();
        }

        public void writeJavaArray(GroupType schema, Schema avroSchema, Class<?> arrayClass, Object value) {
            Class<?> elementClass = arrayClass.getComponentType();

            if (!elementClass.isPrimitive()) {
                writeObjectArray(schema, avroSchema, (Object[]) value);
                return;
            }

            switch (avroSchema.getElementType().getType()) {
            case BOOLEAN:
                Preconditions.checkArgument(elementClass == boolean.class,
                        "Cannot write as boolean array: " + arrayClass.getName());
                writeBooleanArray((boolean[]) value);
                break;
            case INT:
                if (elementClass == byte.class) {
                    writeByteArray((byte[]) value);
                } else if (elementClass == char.class) {
                    writeCharArray((char[]) value);
                } else if (elementClass == short.class) {
                    writeShortArray((short[]) value);
                } else if (elementClass == int.class) {
                    writeIntArray((int[]) value);
                } else {
                    throw new IllegalArgumentException("Cannot write as an int array: " + arrayClass.getName());
                }
                break;
            case LONG:
                Preconditions.checkArgument(elementClass == long.class,
                        "Cannot write as long array: " + arrayClass.getName());
                writeLongArray((long[]) value);
                break;
            case FLOAT:
                Preconditions.checkArgument(elementClass == float.class,
                        "Cannot write as float array: " + arrayClass.getName());
                writeFloatArray((float[]) value);
                break;
            case DOUBLE:
                Preconditions.checkArgument(elementClass == double.class,
                        "Cannot write as double array: " + arrayClass.getName());
                writeDoubleArray((double[]) value);
                break;
            default:
                throw new IllegalArgumentException(
                        "Cannot write " + avroSchema.getElementType() + " array: " + arrayClass.getName());
            }
        }

        protected void writeBooleanArray(boolean[] array) {
            if (array.length > 0) {
                startArray();
                for (boolean element : array) {
                    recordConsumer.addBoolean(element);
                }
                endArray();
            }
        }

        protected void writeByteArray(byte[] array) {
            if (array.length > 0) {
                startArray();
                for (byte element : array) {
                    recordConsumer.addInteger(element);
                }
                endArray();
            }
        }

        protected void writeShortArray(short[] array) {
            if (array.length > 0) {
                startArray();
                for (short element : array) {
                    recordConsumer.addInteger(element);
                }
                endArray();
            }
        }

        protected void writeCharArray(char[] array) {
            if (array.length > 0) {
                startArray();
                for (char element : array) {
                    recordConsumer.addInteger(element);
                }
                endArray();
            }
        }

        protected void writeIntArray(int[] array) {
            if (array.length > 0) {
                startArray();
                for (int element : array) {
                    recordConsumer.addInteger(element);
                }
                endArray();
            }
        }

        protected void writeLongArray(long[] array) {
            if (array.length > 0) {
                startArray();
                for (long element : array) {
                    recordConsumer.addLong(element);
                }
                endArray();
            }
        }

        protected void writeFloatArray(float[] array) {
            if (array.length > 0) {
                startArray();
                for (float element : array) {
                    recordConsumer.addFloat(element);
                }
                endArray();
            }
        }

        protected void writeDoubleArray(double[] array) {
            if (array.length > 0) {
                startArray();
                for (double element : array) {
                    recordConsumer.addDouble(element);
                }
                endArray();
            }
        }
    }

    /**
     * For backward-compatibility. This preserves how lists were written in 1.x:
     * elements are written directly into the repeated {@code array} field of
     * the wrapper group, with no per-element group.
     */
    private class TwoLevelListWriter extends ListWriter {
        /**
         * Writes every element of the collection into the repeated field; an
         * empty collection writes nothing.
         *
         * @throws NullPointerException if writing fails with a
         *         NullPointerException and the collection contains a null
         *         element; the message gives the element's position
         */
        @Override
        public void writeCollection(GroupType schema, Schema avroSchema, Collection<?> array) {
            if (array.size() > 0) {
                recordConsumer.startField(OLD_LIST_REPEATED_NAME, 0);
                try {
                    for (Object elt : array) {
                        writeValue(schema.getType(0), avroSchema.getElementType(), elt);
                    }
                } catch (NullPointerException e) {
                    // Look for a null element only after a failure, so values
                    // that writeValue handles without error are unaffected.
                    int position = 0;
                    for (Object elt : array) {
                        if (elt == null) {
                            throw nullElementError(position, e);
                        }
                        position++;
                    }
                    // No element was null: the failure has another origin.
                    throw e;
                }
                recordConsumer.endField(OLD_LIST_REPEATED_NAME, 0);
            }
        }

        /**
         * Writes every element of the object array into the repeated field; an
         * empty array writes nothing.
         *
         * @throws NullPointerException if writing fails with a
         *         NullPointerException and the array contains a null element;
         *         the message gives the element's position
         */
        @Override
        protected void writeObjectArray(GroupType type, Schema schema, Object[] array) {
            if (array.length > 0) {
                recordConsumer.startField(OLD_LIST_REPEATED_NAME, 0);
                try {
                    for (Object element : array) {
                        writeValue(type.getType(0), schema.getElementType(), element);
                    }
                } catch (NullPointerException e) {
                    // Same diagnosis as writeCollection: only after a failure.
                    for (int position = 0; position < array.length; position++) {
                        if (array[position] == null) {
                            throw nullElementError(position, e);
                        }
                    }
                    throw e;
                }
                recordConsumer.endField(OLD_LIST_REPEATED_NAME, 0);
            }
        }

        /** Opens the repeated field that receives the values of a primitive array. */
        @Override
        protected void startArray() {
            recordConsumer.startField(OLD_LIST_REPEATED_NAME, 0);
        }

        /** Closes the repeated field opened by {@link #startArray()}. */
        @Override
        protected void endArray() {
            recordConsumer.endField(OLD_LIST_REPEATED_NAME, 0);
        }

        /**
         * Builds the exception reported for a null element. The type stays
         * NullPointerException so existing handlers keep working, and the
         * original failure is attached as the cause.
         */
        private NullPointerException nullElementError(int position, NullPointerException cause) {
            NullPointerException explained = new NullPointerException(
                    "Array contains a null element at " + position + ". Set "
                            + WRITE_OLD_LIST_STRUCTURE + "=false to write lists with the three-level"
                            + " structure, which accepts null elements when the element type is optional.");
            explained.initCause(cause);
            return explained;
        }
    }

    /**
     * Writes lists with a repeated {@code list} group inside the wrapper group
     * and one {@code element} field inside each repeated group. Every element,
     * whether it comes from a collection, an object array or a primitive
     * array, is written as its own repeated group instance.
     */
    private class ThreeLevelListWriter extends ListWriter {
        /**
         * Writes each collection element as one repeated group; an empty
         * collection writes nothing.
         *
         * @throws RuntimeException if an element is null and the element type is not optional
         */
        @Override
        protected void writeCollection(GroupType type, Schema schema, Collection<?> collection) {
            if (collection.size() > 0) {
                recordConsumer.startField(LIST_REPEATED_NAME, 0);
                GroupType repeatedType = type.getType(0).asGroupType();
                Type elementType = repeatedType.getType(0);
                for (Object element : collection) {
                    writeElement(elementType, schema, element);
                }
                recordConsumer.endField(LIST_REPEATED_NAME, 0);
            }
        }

        /**
         * Writes each array element as one repeated group; an empty array
         * writes nothing.
         *
         * @throws RuntimeException if an element is null and the element type is not optional
         */
        @Override
        protected void writeObjectArray(GroupType type, Schema schema, Object[] array) {
            if (array.length > 0) {
                recordConsumer.startField(LIST_REPEATED_NAME, 0);
                GroupType repeatedType = type.getType(0).asGroupType();
                Type elementType = repeatedType.getType(0);
                for (Object element : array) {
                    writeElement(elementType, schema, element);
                }
                recordConsumer.endField(LIST_REPEATED_NAME, 0);
            }
        }

        /** Opens the repeated {@code list} field; each value then gets its own group. */
        @Override
        protected void startArray() {
            recordConsumer.startField(LIST_REPEATED_NAME, 0);
        }

        /** Closes the repeated {@code list} field opened by {@link #startArray()}. */
        @Override
        protected void endArray() {
            recordConsumer.endField(LIST_REPEATED_NAME, 0);
        }

        // The primitive-array writers are overridden so that each value is
        // wrapped in its own repeated group, the same layout writeCollection
        // and writeObjectArray produce. Putting all values into a single
        // group would add several values to the one "element" field.

        @Override
        protected void writeBooleanArray(boolean[] array) {
            if (array.length > 0) {
                startArray();
                for (boolean element : array) {
                    startElement();
                    recordConsumer.addBoolean(element);
                    endElement();
                }
                endArray();
            }
        }

        @Override
        protected void writeByteArray(byte[] array) {
            if (array.length > 0) {
                startArray();
                for (byte element : array) {
                    startElement();
                    recordConsumer.addInteger(element);
                    endElement();
                }
                endArray();
            }
        }

        @Override
        protected void writeShortArray(short[] array) {
            if (array.length > 0) {
                startArray();
                for (short element : array) {
                    startElement();
                    recordConsumer.addInteger(element);
                    endElement();
                }
                endArray();
            }
        }

        @Override
        protected void writeCharArray(char[] array) {
            if (array.length > 0) {
                startArray();
                for (char element : array) {
                    startElement();
                    recordConsumer.addInteger(element);
                    endElement();
                }
                endArray();
            }
        }

        @Override
        protected void writeIntArray(int[] array) {
            if (array.length > 0) {
                startArray();
                for (int element : array) {
                    startElement();
                    recordConsumer.addInteger(element);
                    endElement();
                }
                endArray();
            }
        }

        @Override
        protected void writeLongArray(long[] array) {
            if (array.length > 0) {
                startArray();
                for (long element : array) {
                    startElement();
                    recordConsumer.addLong(element);
                    endElement();
                }
                endArray();
            }
        }

        @Override
        protected void writeFloatArray(float[] array) {
            if (array.length > 0) {
                startArray();
                for (float element : array) {
                    startElement();
                    recordConsumer.addFloat(element);
                    endElement();
                }
                endArray();
            }
        }

        @Override
        protected void writeDoubleArray(double[] array) {
            if (array.length > 0) {
                startArray();
                for (double element : array) {
                    startElement();
                    recordConsumer.addDouble(element);
                    endElement();
                }
                endArray();
            }
        }

        /**
         * Writes one element as a repeated group. A null element produces an
         * empty group when the element type is optional and is rejected
         * otherwise.
         */
        private void writeElement(Type elementType, Schema schema, Object element) {
            recordConsumer.startGroup(); // repeated group array, middle layer
            if (element != null) {
                recordConsumer.startField(LIST_ELEMENT_NAME, 0);
                writeValue(elementType, schema.getElementType(), element);
                recordConsumer.endField(LIST_ELEMENT_NAME, 0);
            } else if (!elementType.isRepetition(Type.Repetition.OPTIONAL)) {
                throw new RuntimeException("Null list element for " + schema.getName());
            }
            recordConsumer.endGroup();
        }

        /** Opens the repeated group and the element field for one primitive value. */
        private void startElement() {
            recordConsumer.startGroup(); // repeated group array, middle layer
            recordConsumer.startField(LIST_ELEMENT_NAME, 0);
        }

        /** Closes the element field and the repeated group opened by startElement. */
        private void endElement() {
            recordConsumer.endField(LIST_ELEMENT_NAME, 0);
            recordConsumer.endGroup();
        }
    }
}