co.cask.hydrator.plugin.batch.spark.test.GDTreeTest.java Source code

Java tutorial

Introduction

Here is the source code for co.cask.hydrator.plugin.batch.spark.test.GDTreeTest.java

Source

/*
 * Copyright  2016 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package co.cask.hydrator.plugin.batch.spark.test;

import co.cask.cdap.api.data.format.StructuredRecord;
import co.cask.cdap.api.data.schema.Schema;
import co.cask.cdap.api.dataset.table.Table;
import co.cask.cdap.datapipeline.DataPipelineApp;
import co.cask.cdap.datapipeline.SmartWorkflow;
import co.cask.cdap.etl.api.batch.SparkCompute;
import co.cask.cdap.etl.api.batch.SparkSink;
import co.cask.cdap.etl.mock.batch.MockSink;
import co.cask.cdap.etl.mock.batch.MockSource;
import co.cask.cdap.etl.mock.test.HydratorTestBase;
import co.cask.cdap.etl.proto.v2.ETLBatchConfig;
import co.cask.cdap.etl.proto.v2.ETLPlugin;
import co.cask.cdap.etl.proto.v2.ETLStage;
import co.cask.cdap.proto.Id;
import co.cask.cdap.proto.artifact.AppRequest;
import co.cask.cdap.proto.artifact.ArtifactSummary;
import co.cask.cdap.proto.id.ApplicationId;
import co.cask.cdap.proto.id.ArtifactId;
import co.cask.cdap.proto.id.NamespaceId;
import co.cask.cdap.test.ApplicationManager;
import co.cask.cdap.test.DataSetManager;
import co.cask.cdap.test.TestConfiguration;
import co.cask.cdap.test.WorkflowManager;
import co.cask.hydrator.plugin.spark.GDTreeClassifier;
import co.cask.hydrator.plugin.spark.GDTreeTrainer;
import co.cask.hydrator.plugin.spark.TwitterStreamingSource;
import com.google.common.collect.ImmutableMap;
import org.apache.commons.io.FileUtils;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;

/**
 * Pipeline test for the GDTree plugins.
 *
 * <p>The single test runs two pipelines back to back. The first uses the {@link GDTreeTrainer}
 * SparkSink to train a model from the bundled {@code trainData.csv} resource. The second uses the
 * {@link GDTreeClassifier} SparkCompute, configured with the same model fileset and path, to label a
 * handful of flight records. It then compares the labeled output against an expected set.
 */
public class GDTreeTest extends HydratorTestBase {

    @ClassRule
    public static final TestConfiguration CONFIG = new TestConfiguration("explore.enabled", false);
    @ClassRule
    public static TemporaryFolder temporaryFolder = new TemporaryFolder();

    protected static final ArtifactId DATAPIPELINE_ARTIFACT_ID = NamespaceId.DEFAULT.artifact("data-pipeline",
            "3.5.0");
    protected static final ArtifactSummary DATAPIPELINE_ARTIFACT = new ArtifactSummary("data-pipeline", "3.5.0");

    private static final String LABELED_RECORDS = "labeledRecords";

    /** Name of the training CSV, both as a classpath resource (prefixed with '/') and inside sourceFolder. */
    private static final String TRAIN_DATA_FILE = "trainData.csv";
    /** Fileset name and path shared by the trainer (writer config) and the classifier (reader config). */
    private static final String MODEL_FILESET = "gd-tree-model";
    private static final String MODEL_PATH = "output";
    /** Both phases deploy under this name, so the second deployment reuses the first app's name. */
    private static final String APP_NAME = "SinglePhaseApp";
    private static final long WORKFLOW_TIMEOUT_MINUTES = 5;

    /** Number of comma-separated columns each training row must provide (one per SCHEMA field). */
    private static final int EXPECTED_COLUMNS = 17;
    /** Column index of depDelayMins in trainData.csv; matches its position in SCHEMA. */
    private static final int DEP_DELAY_COLUMN = 11;
    /** Departure delays strictly greater than this many minutes are labeled delayed = 1.0, otherwise 0.0. */
    private static final double DELAY_THRESHOLD_MINS = 40;

    /** Schema of the raw flight records (without the "delayed" label). */
    private static final Schema SCHEMA = Schema.recordOf("flightData",
            Schema.Field.of("dofM", Schema.nullableOf(Schema.of(Schema.Type.INT))),
            Schema.Field.of("dofW", Schema.nullableOf(Schema.of(Schema.Type.INT))),
            Schema.Field.of("carrier", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE))),
            Schema.Field.of("tailNum", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
            Schema.Field.of("flightNum", Schema.nullableOf(Schema.of(Schema.Type.INT))),
            Schema.Field.of("originId", Schema.nullableOf(Schema.of(Schema.Type.INT))),
            Schema.Field.of("origin", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
            Schema.Field.of("destId", Schema.nullableOf(Schema.of(Schema.Type.INT))),
            Schema.Field.of("dest", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
            Schema.Field.of("scheduleDepTime", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE))),
            Schema.Field.of("deptime", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE))),
            Schema.Field.of("depDelayMins", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE))),
            Schema.Field.of("scheduledArrTime", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE))),
            Schema.Field.of("arrTime", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE))),
            Schema.Field.of("arrDelay", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE))),
            Schema.Field.of("elapsedTime", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE))),
            Schema.Field.of("distance", Schema.nullableOf(Schema.of(Schema.Type.INT))));

    /** Temporary directory that receives a copy of the training CSV before each test. */
    private static File sourceFolder;

    /**
     * Registers the data-pipeline artifact and the spark plugin artifact, then creates the working folder.
     */
    @BeforeClass
    public static void setupTest() throws Exception {
        // add the artifact for etl batch app
        setupBatchArtifacts(DATAPIPELINE_ARTIFACT_ID, DataPipelineApp.class);

        // add artifact for spark plugins
        // NOTE(review): GDTreeClassifier is not listed explicitly; presumably it is packaged because it shares
        // GDTreeTrainer's package — confirm against addPluginArtifact's packaging rules.
        addPluginArtifact(NamespaceId.DEFAULT.artifact("spark-plugins", "1.0.0"), DATAPIPELINE_ARTIFACT_ID,
                GDTreeTrainer.class, TwitterStreamingSource.class);

        sourceFolder = temporaryFolder.newFolder("GDTree");
    }

    /**
     * Deletes the temporary folder. The {@code @ClassRule} would also clean it up; the explicit call is kept.
     */
    @AfterClass
    public static void tearDown() throws Exception {
        temporaryFolder.delete();
    }

    /**
     * Copies the {@code /trainData.csv} classpath resource into {@link #sourceFolder}.
     */
    @Before
    public void copyFiles() throws Exception {
        URL testFileUrl = getClass().getResource("/" + TRAIN_DATA_FILE);
        Assert.assertNotNull("Test resource /" + TRAIN_DATA_FILE + " not found on the classpath", testFileUrl);
        // copyURLToFile reads through the URL itself; new File(url.getFile()) would keep percent-escapes
        // (e.g. %20) and fail when the classpath location contains spaces.
        FileUtils.copyURLToFile(testFileUrl, new File(sourceFolder, TRAIN_DATA_FILE));
    }

    @Test
    public void testSparkSinkAndCompute() throws Exception {
        // use the SparkSink(GDTreeTrainer) to train a model
        testSinglePhaseWithSparkSink();
        // use the SparkCompute(GDTreeClassifier) to classify the records
        testSinglePhaseWithSparkCompute();
    }

    /**
     * Phase 1: runs {@code source --> GDTreeTrainer sink}.
     *
     * <p>The source receives the labeled records from {@code trainData.csv}. The trainer is configured with
     * fileSetName/path {@value #MODEL_FILESET}/{@value #MODEL_PATH} and label field "delayed".
     */
    private void testSinglePhaseWithSparkSink() throws Exception {
        /*
         * source --> sparksink
         */
        String inputTable = "flight-data";
        Map<String, String> properties = new ImmutableMap.Builder<String, String>()
                .put("fileSetName", MODEL_FILESET).put("path", MODEL_PATH)
                .put("featuresToInclude",
                        "dofM,dofW,carrier,originId,destId,scheduleDepTime,scheduledArrTime,elapsedTime")
                .put("labelField", "delayed").put("maxClass", "2").put("maxDepth", "10").put("maxIteration", "5")
                .build();

        ETLBatchConfig etlConfig = ETLBatchConfig.builder("* * * * *")
                .addStage(new ETLStage("source", MockSource.getPlugin(inputTable, getTrainerSchema(SCHEMA))))
                .addStage(new ETLStage("customsink",
                        new ETLPlugin(GDTreeTrainer.PLUGIN_NAME, SparkSink.PLUGIN_TYPE, properties, null)))
                .addConnection("source", "customsink").build();

        ApplicationManager appManager = deployPipeline(etlConfig);

        // write the sample training records to the source table
        List<StructuredRecord> messagesToWrite = getInputData();
        DataSetManager<Table> inputManager = getDataset(Id.Namespace.DEFAULT, inputTable);
        MockSource.writeInput(inputManager, messagesToWrite);

        // manually trigger the pipeline
        runWorkflow(appManager);
    }

    /**
     * Reads the training rows from the copied {@code trainData.csv}.
     *
     * <p>Each non-blank line must have at least {@value #EXPECTED_COLUMNS} comma-separated values in
     * SCHEMA field order. A binary "delayed" label is appended: 1.0 when depDelayMins exceeds
     * {@value #DELAY_THRESHOLD_MINS}, otherwise 0.0.
     *
     * @return one record per non-blank line
     * @throws IOException if the file cannot be read or a line has too few columns
     */
    private List<StructuredRecord> getInputData() throws IOException {
        List<StructuredRecord> records = new ArrayList<>();
        File file = new File(sourceFolder, TRAIN_DATA_FILE);
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            String line;
            int lineNumber = 0;
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] flightData = line.split(",");
                if (flightData.length < EXPECTED_COLUMNS) {
                    throw new IOException(String.format("Expected %d columns at %s:%d but found %d: %s",
                            EXPECTED_COLUMNS, file, lineNumber, flightData.length, line));
                }
                double depDelayMins = Double.parseDouble(flightData[DEP_DELAY_COLUMN]);
                // For binary classification the label is 1.0 (delayed) or 0.0 (on time).
                double delayed = depDelayMins > DELAY_THRESHOLD_MINS ? 1.0 : 0.0;
                records.add(new Flight(Integer.parseInt(flightData[0]), Integer.parseInt(flightData[1]),
                        Double.parseDouble(flightData[2]), flightData[3], Integer.parseInt(flightData[4]),
                        Integer.parseInt(flightData[5]), flightData[6], Integer.parseInt(flightData[7]), flightData[8],
                        Integer.parseInt(flightData[9]), Double.parseDouble(flightData[10]), depDelayMins,
                        Double.parseDouble(flightData[12]), Double.parseDouble(flightData[13]),
                        Double.parseDouble(flightData[14]), Double.parseDouble(flightData[15]),
                        Integer.parseInt(flightData[16]), delayed).toStructuredRecord());
            }
        }
        return records;
    }

    /**
     * Returns a copy of {@code schema} with an extra nullable double "delayed" field.
     *
     * <p>The copy is named {@code <recordName>.predicted}.
     */
    private Schema getTrainerSchema(Schema schema) {
        List<Schema.Field> fields = new ArrayList<>(schema.getFields());
        fields.add(Schema.Field.of("delayed", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE))));
        return Schema.recordOf(schema.getRecordName() + ".predicted", fields);
    }

    /**
     * Phase 2: runs {@code source --> GDTreeClassifier compute --> mock sink}.
     *
     * <p>Six unlabeled flights are classified and the records written to the sink must equal the
     * expected labeled set.
     */
    private void testSinglePhaseWithSparkCompute() throws Exception {
        String inputTable = "spark-compute";
        /*
         * source --> sparkcompute --> sink
         */
        // Excluding these fields leaves exactly the eight features the trainer was configured with.
        ETLPlugin classifierPlugin = new ETLPlugin(GDTreeClassifier.PLUGIN_NAME, SparkCompute.PLUGIN_TYPE,
                ImmutableMap.of("fileSetName", MODEL_FILESET, "path", MODEL_PATH, "featuresToExclude",
                        "tailNum,flightNum,origin,dest,deptime,depDelayMins,arrTime,arrDelay,distance",
                        "predictionField", "delayed"),
                null);
        ETLBatchConfig etlConfig = ETLBatchConfig.builder("* * * * *")
                .addStage(new ETLStage("source", MockSource.getPlugin(inputTable, SCHEMA)))
                .addStage(new ETLStage("sparkcompute", classifierPlugin))
                .addStage(new ETLStage("sink", MockSink.getPlugin(LABELED_RECORDS)))
                .addConnection("source", "sparkcompute").addConnection("sparkcompute", "sink").build();

        ApplicationManager appManager = deployPipeline(etlConfig);

        // Flight records to be labeled.
        List<StructuredRecord> messagesToWrite = new ArrayList<>();
        messagesToWrite.add(new Flight(4, 6, 1.0, "N327AA", 1, 12478, "JFK", 12892, "LAX", 900, 1005.0, 65.0,
                1225.0, 1324.0, 59.0, 385.0, 2475).toStructuredRecord());
        messagesToWrite.add(new Flight(25, 6, 2.0, "N0EGMQ", 3419, 10397, "ATL", 12953, "LGA", 1150, 1229.0, 39.0,
                1359.0, 1448.0, 49.0, 129.0, 762).toStructuredRecord());
        messagesToWrite.add(new Flight(4, 6, 3.0, "N14991", 6159, 13930, "ORD", 13198, "MCI", 2030, 2118.0, 48.0,
                2205.0, 2321.0, 76.0, 95.0, 403).toStructuredRecord());
        messagesToWrite.add(new Flight(29, 3, 1.0, "N355AA", 2407, 12892, "LAX", 11298, "DFW", 1025, 1023.0, 0.0,
                1530.0, 1523.0, 0.0, 185.0, 1235).toStructuredRecord());
        messagesToWrite.add(new Flight(2, 4, 4.0, "N919DE", 1908, 13930, "ORD", 11433, "DTW", 1641, 1902.0, 141.0,
                1905.0, 2117.0, 132.0, 84.0, 235).toStructuredRecord());
        messagesToWrite.add(new Flight(2, 4, 4.0, "N933DN", 1791, 10397, "ATL", 15376, "TUS", 1855, 2014.0, 79.0,
                2108.0, 2159.0, 51.0, 253.0, 1541).toStructuredRecord());

        DataSetManager<Table> inputManager = getDataset(Id.Namespace.DEFAULT, inputTable);
        MockSource.writeInput(inputManager, messagesToWrite);

        // manually trigger the pipeline
        runWorkflow(appManager);

        DataSetManager<Table> labeledTexts = getDataset(LABELED_RECORDS);
        List<StructuredRecord> structuredRecords = MockSink.readOutput(labeledTexts);

        // The expected labels coincide with the training rule (depDelayMins > 40 -> 1.0) applied to each record.
        Set<Flight> expected = new HashSet<>();
        expected.add(new Flight(4, 6, 1.0, "N327AA", 1, 12478, "JFK", 12892, "LAX", 900, 1005.0, 65.0, 1225.0,
                1324.0, 59.0, 385.0, 2475, 1.0));
        expected.add(new Flight(25, 6, 2.0, "N0EGMQ", 3419, 10397, "ATL", 12953, "LGA", 1150, 1229.0, 39.0, 1359.0,
                1448.0, 49.0, 129.0, 762, 0.0));
        expected.add(new Flight(4, 6, 3.0, "N14991", 6159, 13930, "ORD", 13198, "MCI", 2030, 2118.0, 48.0, 2205.0,
                2321.0, 76.0, 95.0, 403, 1.0));
        expected.add(new Flight(29, 3, 1.0, "N355AA", 2407, 12892, "LAX", 11298, "DFW", 1025, 1023.0, 0.0, 1530.0,
                1523.0, 0.0, 185.0, 1235, 0.0));
        expected.add(new Flight(2, 4, 4.0, "N919DE", 1908, 13930, "ORD", 11433, "DTW", 1641, 1902.0, 141.0, 1905.0,
                2117.0, 132.0, 84.0, 235, 1.0));
        expected.add(new Flight(2, 4, 4.0, "N933DN", 1791, 10397, "ATL", 15376, "TUS", 1855, 2014.0, 79.0, 2108.0,
                2159.0, 51.0, 253.0, 1541, 1.0));

        Set<Flight> results = new HashSet<>();
        for (StructuredRecord structuredRecord : structuredRecords) {
            results.add(Flight.fromStructuredRecord(structuredRecord));
        }

        Assert.assertEquals(expected, results);
    }

    /**
     * Deploys {@code etlConfig} on the data-pipeline artifact under {@value #APP_NAME}.
     */
    private ApplicationManager deployPipeline(ETLBatchConfig etlConfig) throws Exception {
        AppRequest<ETLBatchConfig> appRequest = new AppRequest<>(DATAPIPELINE_ARTIFACT, etlConfig);
        ApplicationId appId = NamespaceId.DEFAULT.app(APP_NAME);
        return deployApplication(appId.toId(), appRequest);
    }

    /**
     * Starts the pipeline's SmartWorkflow and waits up to {@value #WORKFLOW_TIMEOUT_MINUTES} minutes for it.
     *
     * <p>NOTE(review): this only waits; nothing here checks that the run succeeded. Confirm whether
     * waitForFinish fails on an unsuccessful run.
     */
    private static void runWorkflow(ApplicationManager appManager) throws Exception {
        WorkflowManager workflowManager = appManager.getWorkflowManager(SmartWorkflow.NAME);
        workflowManager.start();
        workflowManager.waitForFinish(WORKFLOW_TIMEOUT_MINUTES, TimeUnit.MINUTES);
    }
}