Java tutorial
/* * $# * FOS Weka * * Copyright (C) 2013 Feedzai SA * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as * published by the Free Software Foundation, either version 3 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public * License along with this program. If not, see * <http://www.gnu.org/licenses/gpl-3.0.html>. * #$ */ package com.feedzai.fos.impl.weka; import au.com.bytecode.opencsv.CSVReader; import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; import com.esotericsoftware.kryo.serializers.CollectionSerializer; import com.fasterxml.jackson.databind.ObjectMapper; import com.feedzai.fos.api.*; import com.feedzai.fos.common.kryo.CustomUUIDSerializer; import com.feedzai.fos.common.kryo.ScoringRequestEnvelope; import com.feedzai.fos.common.validation.NotBlank; import com.feedzai.fos.common.validation.NotNull; import com.feedzai.fos.impl.weka.config.WekaManagerConfig; import com.feedzai.fos.impl.weka.config.WekaModelConfig; import com.feedzai.fos.impl.weka.utils.WekaUtils; import com.feedzai.fos.impl.weka.utils.pmml.PMMLProducers; import com.feedzai.fos.impl.weka.utils.setter.InstanceSetter; import com.google.common.base.Joiner; import com.google.common.io.Files; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.lang.SerializationUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import weka.classifiers.Classifier; import weka.core.FastVector; import weka.core.Instance; import weka.core.Instances; import java.io.*; import java.net.ServerSocket; import java.net.Socket; import java.util.*; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import static com.feedzai.fos.api.util.ManagerUtils.*; import static com.google.common.base.Preconditions.checkNotNull; /** * This class implements a manager that is able to train and score * using Weka classifiers. * <p/> * Aditionally, it also implements a Kryo endpoint for scoring to be used along * KryoScorer. * * @author Marco Jorge (marco.jorge@feedzai.com) * @author Miguel Duarte (miguel.duarte@feedzai.com) */ public class WekaManager implements Manager { private final static Logger logger = LoggerFactory.getLogger(WekaManager.class); private Thread acceptThread; private ServerSocket serverSocket; ObjectMapper mapper = new ObjectMapper(); private Map<UUID, WekaModelConfig> modelConfigs = new HashMap<>(); private WekaManagerConfig wekaManagerConfig; private WekaScorer wekaScorer; private KryoScoringEndpoint scorerHandler; private volatile boolean acceptThreadRunning = false; /** * Save dirty configurations to disk. * <p/> If saving configuration was not possible, a log is produced but no exception is thrown. */ private synchronized void saveConfiguration() { for (WekaModelConfig wekaModelConfig : modelConfigs.values()) { if (wekaModelConfig.isDirty() && wekaModelConfig.getModelConfig().isStoreModel()) { try { String modelConfigJson = mapper.writeValueAsString(wekaModelConfig.getModelConfig()); // create a new file because this model has never been written if (wekaModelConfig.getHeader() == null) { File file = File.createTempFile(wekaModelConfig.getId().toString(), "." + WekaManagerConfig.HEADER_EXTENSION, wekaManagerConfig.getHeaderLocation()); wekaModelConfig.setHeader(file); } FileUtils.write((wekaModelConfig).getHeader(), modelConfigJson); wekaModelConfig.setDirty(false /* contents have been updated so the model is no longer dirty*/); } catch (IOException e) { logger.error("Could not store configuration for model '{}' (will continue to save others)", wekaModelConfig.getId(), e); } } } } /** * Create a new manager from the given configuration. * <p/> Will lookup any headers files and to to instantiate the model. * <p/> If a model fails, a log is produced but loading other models will continue (no exception is thrown). * * @param wekaManagerConfig the manager configuration */ public WekaManager(WekaManagerConfig wekaManagerConfig) { checkNotNull(wekaManagerConfig, "Manager config cannot be null"); this.wekaManagerConfig = wekaManagerConfig; Collection<File> headers = FileUtils.listFiles(wekaManagerConfig.getHeaderLocation(), new String[] { WekaManagerConfig.HEADER_EXTENSION }, true); for (File header : headers) { logger.trace("Reading model file '{}'", header); FileInputStream fileInputStream = null; try { fileInputStream = new FileInputStream(header); String modelConfigJson = IOUtils.toString(fileInputStream); ModelConfig modelConfig = mapper.readValue(modelConfigJson, ModelConfig.class); WekaModelConfig wekaModelConfig = new WekaModelConfig(modelConfig, wekaManagerConfig); wekaModelConfig.setHeader(header); wekaModelConfig.setDirty(false /* not changed so far */); if (modelConfigs.containsKey(wekaModelConfig.getId())) { logger.error( "Model with ID '{}' is duplicated in the configuration (the configuration from '{}' is discarded)", wekaModelConfig.getId(), header.getAbsolutePath()); } else { modelConfigs.put(wekaModelConfig.getId(), wekaModelConfig); } } catch (Exception e) { logger.error("Could not load from '{}' (continuing to load others)", header, e); } finally { IOUtils.closeQuietly(fileInputStream); } } this.wekaScorer = new WekaScorer(modelConfigs, wekaManagerConfig); try { int port = wekaManagerConfig.getScoringPort(); this.serverSocket = new ServerSocket(port); serverSocket.setReuseAddress(true); final int max_threads = wekaManagerConfig.getMaxSimultaneousScoringThreads(); Runnable acceptRunnable = new Runnable() { ExecutorService executor = Executors.newFixedThreadPool(max_threads); @Override public void run() { acceptThreadRunning = true; try { while (acceptThreadRunning && Thread.currentThread().isInterrupted() == false) { Socket client = serverSocket.accept(); client.setTcpNoDelay(true); scorerHandler = new KryoScoringEndpoint(client, wekaScorer); executor.submit(scorerHandler); } } catch (IOException e) { logger.error(e.getMessage(), e); } } }; acceptThread = new Thread(acceptRunnable); acceptThread.start(); } catch (IOException e) { logger.error(e.getMessage(), e); } } @Override public synchronized UUID addModel(ModelConfig config, Model model) throws FOSException { try { UUID uuid = getUuid(config); File modelFile = createModelFile(wekaManagerConfig.getHeaderLocation(), uuid, model); WekaModelConfig wekaModelConfig = new WekaModelConfig(config, wekaManagerConfig); wekaModelConfig.setId(uuid); wekaModelConfig.setModelDescriptor(getModelDescriptor(model, modelFile)); modelConfigs.put(uuid, wekaModelConfig); wekaScorer.addOrUpdate(wekaModelConfig); saveConfiguration(); logger.debug("Model {} added", uuid); return uuid; } catch (IOException e) { throw new FOSException(e); } } @Override public synchronized UUID addModel(ModelConfig config, @NotBlank ModelDescriptor descriptor) throws FOSException { UUID uuid = getUuid(config); WekaModelConfig wekaModelConfig = new WekaModelConfig(config, wekaManagerConfig); wekaModelConfig.setId(uuid); wekaModelConfig.setModelDescriptor(descriptor); modelConfigs.put(uuid, wekaModelConfig); wekaScorer.addOrUpdate(wekaModelConfig); saveConfiguration(); logger.debug("Model {} added", uuid); return uuid; } @Override public synchronized void removeModel(UUID modelId) throws FOSException { WekaModelConfig wekaModelConfig = modelConfigs.remove(modelId); if (wekaModelConfig == null) { logger.warn("Could not remove model with id {} because it does not exists", modelId); return; } wekaScorer.removeModel(modelId); if (wekaModelConfig.getModelConfig().isStoreModel()) { // delete the header & model file (or else it will be picked up on the next restart) wekaModelConfig.getHeader().delete(); // only delete if is in our header location if (!wekaManagerConfig.getHeaderLocation().toURI().relativize(wekaModelConfig.getModel().toURI()) .isAbsolute()) { wekaModelConfig.getModel().delete(); } } logger.debug("Model {} removed", modelId); } @Override public synchronized void reconfigureModel(UUID modelId, ModelConfig modelConfig) throws FOSException { WekaModelConfig wekaModelConfig = this.modelConfigs.get(modelId); wekaModelConfig.update(modelConfig); wekaScorer.addOrUpdate(wekaModelConfig); saveConfiguration(); logger.debug("Model {} reconfigured", modelId); } @Override public synchronized void reconfigureModel(UUID modelId, ModelConfig modelConfig, Model model) throws FOSException { try { File modelFile = createModelFile(wekaManagerConfig.getHeaderLocation(), modelId, model); WekaModelConfig wekaModelConfig = this.modelConfigs.get(modelId); wekaModelConfig.update(modelConfig); ModelDescriptor descriptor = getModelDescriptor(model, modelFile); wekaModelConfig.setModelDescriptor(descriptor); wekaScorer.addOrUpdate(wekaModelConfig); saveConfiguration(); logger.debug("Model {} reconfigured", modelId); } catch (IOException e) { throw new FOSException(e); } } @Override public synchronized void reconfigureModel(UUID modelId, ModelConfig modelConfig, @NotBlank ModelDescriptor descriptor) throws FOSException { File file = new File(descriptor.getModelFilePath()); WekaModelConfig wekaModelConfig = this.modelConfigs.get(modelId); wekaModelConfig.update(modelConfig); wekaModelConfig.setModelDescriptor(descriptor); wekaScorer.addOrUpdate(wekaModelConfig); saveConfiguration(); } @Override @NotNull public synchronized Map<UUID, ModelConfig> listModels() { Map<UUID, ModelConfig> result = new HashMap<>(modelConfigs.size()); for (Map.Entry<UUID, WekaModelConfig> entry : modelConfigs.entrySet()) { result.put(entry.getKey(), entry.getValue().getModelConfig()); } return result; } @Override @NotNull public WekaScorer getScorer() { return wekaScorer; } @Override public synchronized UUID trainAndAdd(ModelConfig config, List<Object[]> instances) throws FOSException { Model trainedModel = train(config, instances); return addModel(config, trainedModel); } @Override public synchronized UUID trainAndAddFile(ModelConfig config, String path) throws FOSException { Model trainedModel = trainFile(config, path); return addModel(config, trainedModel); } @Override public Model train(ModelConfig config, List<Object[]> instances) throws FOSException { checkNotNull(instances, "Instances must be supplied"); checkNotNull(config, "Config must be supplied"); long time = System.currentTimeMillis(); WekaModelConfig wekaModelConfig = new WekaModelConfig(config, wekaManagerConfig); Classifier classifier = WekaClassifierFactory.create(config); FastVector attributes = WekaUtils.instanceFields2Attributes(wekaModelConfig.getClassIndex(), config.getAttributes()); InstanceSetter[] instanceSetters = WekaUtils.instanceFields2ValueSetters(config.getAttributes(), InstanceType.TRAINING); Instances wekaInstances = new Instances(config.getProperty(WekaModelConfig.CLASSIFIER_IMPL), attributes, instances.size()); for (Object[] objects : instances) { wekaInstances.add(WekaUtils.objectArray2Instance(objects, instanceSetters, attributes)); } trainClassifier(wekaModelConfig.getClassIndex(), classifier, wekaInstances); final byte[] bytes = SerializationUtils.serialize(classifier); logger.debug("Trained model with {} instances in {}ms", instances.size(), (System.currentTimeMillis() - time)); return new ModelBinary(bytes); } @Override public Model trainFile(ModelConfig config, String path) throws FOSException { checkNotNull(path, "Config must be supplied"); checkNotNull(path, "Path must be supplied"); long time = System.currentTimeMillis(); WekaModelConfig wekaModelConfig = new WekaModelConfig(config, wekaManagerConfig); Classifier classifier = WekaClassifierFactory.create(config); List<Attribute> attributeList = config.getAttributes(); FastVector attributes = WekaUtils.instanceFields2Attributes(wekaModelConfig.getClassIndex(), config.getAttributes()); InstanceSetter[] instanceSetters = WekaUtils.instanceFields2ValueSetters(config.getAttributes(), InstanceType.TRAINING); List<Instance> instances = new ArrayList(); String[] line; try { FileReader fileReader = new FileReader(path); CSVReader csvReader = new CSVReader(fileReader); while ((line = csvReader.readNext()) != null) { // parsing is done by InstanceSetter's instances.add(WekaUtils.objectArray2Instance(line, instanceSetters, attributes)); } } catch (Exception e) { throw new FOSException(e.getMessage(), e); } Instances wekaInstances = new Instances(config.getProperty(WekaModelConfig.CLASSIFIER_IMPL), attributes, instances.size()); for (Instance instance : instances) { wekaInstances.add(instance); } trainClassifier(wekaModelConfig.getClassIndex(), classifier, wekaInstances); final byte[] bytes = SerializationUtils.serialize(classifier); logger.debug("Trained model with {} instances in {}ms", instances.size(), (System.currentTimeMillis() - time)); return new ModelBinary(bytes); } /** * Will save the configuration to file. * * @throws FOSException when there are IO problems writing the configuration to file */ @Override public synchronized void close() throws FOSException { acceptThreadRunning = false; if (scorerHandler != null) { scorerHandler.running = false; scorerHandler.close(); } IOUtils.closeQuietly(serverSocket); saveConfiguration(); } /** * Returns a new {@link com.feedzai.fos.api.ModelDescriptor} for the given {@code model} and {@code file}. * * @param model The {@link Model} with the classifier. * @param modelFile The file where the model will be saved. * @return A new {@link com.feedzai.fos.api.ModelDescriptor} * @throws FOSException If the given {@code model} is of an unknown instance. */ private ModelDescriptor getModelDescriptor(Model model, File modelFile) throws FOSException { if (model instanceof ModelBinary) { return new ModelDescriptor(ModelDescriptor.Format.BINARY, modelFile.getAbsolutePath()); } else if (model instanceof ModelPMML) { return new ModelDescriptor(ModelDescriptor.Format.PMML, modelFile.getAbsolutePath()); } else { throw new FOSException("Unsupported Model type '" + model.getClass().getSimpleName() + "'."); } } /** * Trains the given {@code classifier} using the given {@link com.feedzai.fos.impl.weka.config.WekaModelConfig modelConfig} * and {@link weka.core.Instances wekaInstances}. * * @param classIndex The index of the class. * @param classifier The classifier to be trained. * @param wekaInstances The training instances. * @throws FOSException If it fails to train the classifier. */ private void trainClassifier(int classIndex, Classifier classifier, Instances wekaInstances) throws FOSException { wekaInstances.setClassIndex(classIndex == -1 ? wekaInstances.numAttributes() - 1 : classIndex); try { classifier.buildClassifier(wekaInstances); } catch (Exception e) { throw new FOSException(e.getMessage(), e); } } @Override public void save(UUID uuid, String savepath) throws FOSException { try { File source = modelConfigs.get(uuid).getModel(); File destination = new File(savepath); Files.copy(source, destination); } catch (Exception e) { throw new FOSException("Unable to save model " + uuid + " to " + savepath, e); } } @Override public void saveAsPMML(UUID uuid, String saveFilePath, boolean compress) throws FOSException { Classifier classifier = wekaScorer.getClassifier(uuid); File target = new File(saveFilePath); PMMLProducers.produce(classifier, target, compress); } /** * This class should be used to perform scoring requests * to a remote FOS instance that suports kryo end points. * <p/> * It listens on a socket input stream for Kryo serialized {@link ScoringRequestEnvelope} * objects, extracts them and forwards them to the local scorer. * <p/> * The scoring result is then Kryo encoded on the socket output stream. */ private static class KryoScoringEndpoint implements Runnable { public static final int BUFFER_SIZE = 1024; Socket client; Scorer scorer; private volatile boolean running = true; private KryoScoringEndpoint(Socket client, Scorer scorer) throws IOException { this.client = client; this.scorer = scorer; } @Override public void run() { Kryo kryo = new Kryo(); kryo.addDefaultSerializer(UUID.class, new CustomUUIDSerializer()); // workaround for java.util.Arrays$ArrayList missing default constructor kryo.register(Arrays.asList().getClass(), new CollectionSerializer() { protected Collection create(Kryo kryo, Input input, Class<Collection> type) { return new ArrayList(); } }); Input input = new Input(BUFFER_SIZE); Output output = new Output(BUFFER_SIZE); ScoringRequestEnvelope request = null; try (InputStream is = client.getInputStream(); OutputStream os = client.getOutputStream()) { input.setInputStream(is); output.setOutputStream(os); while (running) { request = kryo.readObject(input, ScoringRequestEnvelope.class); List<double[]> scores = scorer.score(request.getUUIDs(), request.getInstance()); kryo.writeObject(output, scores); output.flush(); os.flush(); } } catch (Exception e) { if (request != null) { logger.error("Error scoring instance {} for models {}", Arrays.toString(request.getInstance()), Arrays.toString(request.getUUIDs().toArray()), e); } else { logger.error("Error scoring instance", e); } } finally { IOUtils.closeQuietly(client); running = false; } } public void close() { running = false; IOUtils.closeQuietly(client); } } }