// Java tutorial
/* * Copyright 2015 Cask Data, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package co.cask.cdap.template.etl.common; import com.google.common.base.Throwables; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.mapreduce.RecordWriter; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.hadoop.mapreduce.lib.db.DBConfiguration; import org.apache.hadoop.mapreduce.lib.db.DBOutputFormat; import org.apache.hadoop.mapreduce.lib.db.DBWritable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.sql.Connection; import java.sql.Driver; import java.sql.DriverManager; import java.sql.PreparedStatement; import java.sql.SQLException; /** * Class that extends {@link DBOutputFormat} to load the database driver class correctly. * * @param <K> - Key passed to this class to be written * @param <V> - Value passed to this class to be written. The value is ignored. 
* * {@inheritDoc} */ public class ETLDBOutputFormat<K extends DBWritable, V> extends DBOutputFormat<K, V> { private static final Logger LOG = LoggerFactory.getLogger(ETLDBOutputFormat.class); private Driver driver; private JDBCDriverShim driverShim; @Override public RecordWriter<K, V> getRecordWriter(TaskAttemptContext context) throws IOException { Configuration conf = context.getConfiguration(); DBConfiguration dbConf = new DBConfiguration(conf); String tableName = dbConf.getOutputTableName(); String[] fieldNames = dbConf.getOutputFieldNames(); if (fieldNames == null) { fieldNames = new String[dbConf.getOutputFieldCount()]; } try { Connection connection = getConnection(conf); PreparedStatement statement = connection.prepareStatement(constructQuery(tableName, fieldNames)); return new DBRecordWriter(connection, statement) { @Override public void close(TaskAttemptContext context) throws IOException { super.close(context); try { DriverManager.deregisterDriver(driverShim); } catch (SQLException e) { throw new IOException(e); } } }; } catch (Exception ex) { throw new IOException(ex.getMessage()); } } private Connection getConnection(Configuration conf) { Connection connection; try { String url = conf.get(DBConfiguration.URL_PROPERTY); try { // throws SQLException if no suitable driver is found DriverManager.getDriver(url); } catch (SQLException e) { if (driverShim == null) { if (driver == null) { ClassLoader classLoader = conf.getClassLoader(); @SuppressWarnings("unchecked") Class<? extends Driver> driverClass = (Class<? extends Driver>) classLoader .loadClass(conf.get(DBConfiguration.DRIVER_CLASS_PROPERTY)); driver = driverClass.newInstance(); // De-register the default driver that gets registered when driver class is loaded. DBUtils.deregisterAllDrivers(driverClass); } driverShim = new JDBCDriverShim(driver); DriverManager.registerDriver(driverShim); LOG.debug("Registered JDBC driver via shim {}. 
Actual Driver {}.", driverShim, driver); } } if (conf.get(DBConfiguration.USERNAME_PROPERTY) == null) { connection = DriverManager.getConnection(url); } else { connection = DriverManager.getConnection(url, conf.get(DBConfiguration.USERNAME_PROPERTY), conf.get(DBConfiguration.PASSWORD_PROPERTY)); } connection.setAutoCommit(false); connection.setTransactionIsolation(Connection.TRANSACTION_SERIALIZABLE); } catch (Exception e) { throw Throwables.propagate(e); } return connection; } @Override public String constructQuery(String table, String[] fieldNames) { String query = super.constructQuery(table, fieldNames); // Strip the ';' at the end since Oracle doesn't like it. // TODO: Perhaps do a conditional if we can find a way to tell that this is going to Oracle // However, tested this to work on Mysql and Oracle return query.substring(0, query.length() - 1); } }