oxis.yologp.YOLogPDescriptor.java Source code

Java tutorial

Introduction

Here is the source code for oxis.yologp.YOLogPDescriptor.java

Source

/**
 * Copyright (C) 2014 EMBL - European Bioinformatics Institute
 *
 * All rights reserved. This file is part of the YOLogP project.
 *
 * author: oXis (Benjamin Roques)
 *
 * This program is free software; you can redistribute it and/or modify it under
 * the terms of the Creative Commons Attribution-NonCommercial-ShareAlike 4.0
 * International License, please visit
 * http://creativecommons.org/licenses/by-nc-sa/4.0/. All we ask is that proper
 * credit is given for my work, which includes - but is not limited to - adding
 * the above copyright notice to the beginning of your source code files, and to
 * any copyright notice that you may distribute with programs based on this
 * work.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE.
 */
package oxis.yologp;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.openscience.cdk.exception.CDKException;
import org.openscience.cdk.interfaces.IAtomContainer;
import weka.classifiers.trees.RandomForest;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Predicts logP for a batch of molecules.
 *
 * <p>Each molecule is wrapped in a {@link DrugStruct}. {@link #calculate()} runs every
 * DrugStruct on a fixed-size thread pool to compute its properties. It then turns the computed
 * properties of each flagged molecule into a Weka {@link Instance} (28 named descriptors plus
 * 1024 fingerprint bits {@code X0..X1023}) and classifies it with a {@link RandomForest}. The
 * rounded prediction is stored on the molecule under the {@code "predicted"} property.
 *
 * <p>Instances keep mutable state without synchronization and should not be shared between
 * threads.
 */
public class YOLogPDescriptor {

    private static final Logger LOGGER = Logger.getLogger(YOLogPDescriptor.class.getName());

    /** Classpath location of the bundled model used by {@link #calculate()}. */
    private static final String DEFAULT_MODEL_RESOURCE = "/rf.model";

    /** Name of the class (target) attribute of the dataset. */
    private static final String CLASS_ATTRIBUTE = "logp";

    /** Number of fingerprint bit attributes, named {@code X0 .. X1023}. */
    private static final int FINGERPRINT_BITS = 1024;

    /**
     * Named descriptor attributes, in dataset order. Weka models address attributes by index, so
     * this order must stay identical to the one the serialized model was trained with.
     */
    private static final String[] DESCRIPTOR_NAMES = {
        "AromaticBondsCountDescriptor",
        "AromaticAtomsCountDescriptor",
        "BondPartialSigmaChargeMax",
        "AutocorrelationDescriptorMass",
        "EffectiveAtomPolarizabilityMea",
        "MDEDescriptor5",
        "BasicGroupCountDescriptor",
        "AutocorrelationDescriptorCharge2",
        "APolDescriptor",
        "BCUTDescriptor3",
        "AutocorrelationDescriptorPolarizability2",
        "HBondDonorCountDescriptor",
        "PartialTChargePEOEMin",
        "PartialTChargePEOEMed",
        "PartialSigmaChargeMed",
        "BondPartialSigmaChargeMea",
        "PartialSigmaChargeMin",
        "WeightedPathDescriptor5",
        "TPSADescriptor",
        "AutocorrelationDescriptorPolarizability",
        "CarbonTypesDescriptor4",
        "ALOGPDescriptor2",
        "HBondAcceptorCountDescriptor",
        "MannholdLogPDescriptor",
        "FractionalPSADescriptor",
        "ALOGPDescriptor3",
        "ALOGPDescriptor",
        "XLogPDescriptor"
    };

    private int poolSize = Runtime.getRuntime().availableProcessors() * 2;
    private int timeout = 80;

    private String path = "./";

    private List<DrugStruct> listDrug = new ArrayList<>();

    private RandomForest model;

    /**
     * Copy the given DrugStruct instances into this descriptor's working list.
     *
     * @param drugs list of DrugStruct
     */
    private YOLogPDescriptor(List<DrugStruct> drugs) {
        this.listDrug.addAll(drugs);
    }

    /**
     * Create a descriptor from a list of atom containers, wrapping each in a DrugStruct.
     *
     * @param ListContainer list of IAtomContainer
     * @return a new descriptor over the converted molecules
     * @throws CDKException if any container cannot be converted
     */
    public static YOLogPDescriptor fromAtomContainer(List<IAtomContainer> ListContainer) throws CDKException {
        List<DrugStruct> setupContainers = new ArrayList<>(ListContainer.size());
        for (IAtomContainer IAdrug : ListContainer) {
            try {
                setupContainers.add(new DrugStruct(IAdrug));
            } catch (CDKException ex) {
                throw new CDKException("Cannot convert IAtomContainer to DrugStruct", ex);
            }
        }
        return new YOLogPDescriptor(setupContainers);
    }

    /**
     * Create a descriptor from a single atom container.
     *
     * @param container IAtomContainer
     * @return a new descriptor over the converted molecule
     * @throws CDKException if the container cannot be converted
     */
    public static YOLogPDescriptor fromAtomContainer(IAtomContainer container) throws CDKException {
        try {
            return new YOLogPDescriptor(Collections.singletonList(new DrugStruct(container)));
        } catch (CDKException ex) {
            throw new CDKException("Cannot convert IAtomContainer to DrugStruct", ex);
        }
    }

    /**
     * Create a descriptor from a list of SMILES strings.
     *
     * @param listSmiles SMILES strings, one per molecule
     * @return a new descriptor over the parsed molecules
     * @throws CDKException if any SMILES string cannot be turned into a DrugStruct
     */
    public static YOLogPDescriptor fromSmiles(List<String> listSmiles) throws CDKException {
        List<DrugStruct> setupContainers = new ArrayList<>(listSmiles.size());
        for (String smiles : listSmiles) {
            try {
                setupContainers.add(new DrugStruct(smiles));
            } catch (CDKException ex) {
                throw new CDKException("Cannot instanciate IAtomContainer with " + smiles, ex);
            }
        }
        return new YOLogPDescriptor(setupContainers);
    }

    /**
     * Create a descriptor from a single SMILES string.
     *
     * @param smiles SMILES string
     * @return a new descriptor over the parsed molecule
     * @throws CDKException if the SMILES string cannot be turned into a DrugStruct
     */
    public static YOLogPDescriptor fromSmiles(String smiles) throws CDKException {
        try {
            return new YOLogPDescriptor(Collections.singletonList(new DrugStruct(smiles)));
        } catch (CDKException ex) {
            throw new CDKException("Cannot instanciate IAtomContainer with " + smiles, ex);
        }
    }

    /**
     * Create a descriptor from already-instantiated DrugStruct objects. The list is copied.
     *
     * @param listDrugStruct list of DrugStruct
     * @return a new descriptor over the given molecules
     */
    public static YOLogPDescriptor fromDrugStruct(List<DrugStruct> listDrugStruct) {
        // The private constructor copies the list, so no extra defensive copy is needed.
        return new YOLogPDescriptor(listDrugStruct);
    }

    /**
     * Create a descriptor from a single already-instantiated DrugStruct.
     *
     * @param drugStruct the DrugStruct
     * @return a new descriptor over the given molecule
     */
    public static YOLogPDescriptor fromDrugStruct(DrugStruct drugStruct) {
        return new YOLogPDescriptor(Collections.singletonList(drugStruct));
    }

    /**
     * Load the bundled default model, compute all properties and predict the LogP.
     *
     * <p>Molecules whose computation timed out or failed are dropped from the result. A
     * prediction failure is logged and the computed molecules are still returned.
     *
     * @return the computed DrugStruct list (predictions stored under {@code "predicted"})
     * @throws IllegalStateException if the default model cannot be loaded
     */
    public List<DrugStruct> calculate() {
        try (InputStream in = getClass().getResourceAsStream(DEFAULT_MODEL_RESOURCE)) {
            if (in == null) {
                throw new IOException("Resource not found on classpath: " + DEFAULT_MODEL_RESOURCE);
            }
            model = (RandomForest) weka.core.SerializationHelper.read(in);
        } catch (Exception ex) {
            LOGGER.log(Level.SEVERE, "Cannot load default model", ex);
            throw new IllegalStateException("Cannot load default model", ex);
        }
        return computeAndPredict();
    }

    /**
     * Load the model file {@code path + name}, compute all properties and predict the LogP.
     *
     * @param name file name of a serialized RandomForest model, relative to {@link #getPath()}
     * @return the computed DrugStruct list (predictions stored under {@code "predicted"})
     * @throws IllegalStateException if the model cannot be loaded
     */
    public List<DrugStruct> calculate(String name) {
        try {
            model = (RandomForest) weka.core.SerializationHelper.read(path + name);
        } catch (Exception ex) {
            LOGGER.log(Level.SEVERE, "Could not load model " + name, ex);
            throw new IllegalStateException("Could not load model " + name, ex);
        }
        return computeAndPredict();
    }

    /** Shared tail of both {@code calculate} overloads: compute, then best-effort predict. */
    private List<DrugStruct> computeAndPredict() {
        compute();
        try {
            predict();
        } catch (Exception ex) {
            LOGGER.log(Level.SEVERE, "Prediction error", ex);
        }
        return listDrug;
    }

    /**
     * Run every DrugStruct on the thread pool and replace {@code listDrug} with the ones that
     * completed within {@code timeout} seconds.
     *
     * <p>The timeout is applied to each {@link Future#get} call in turn, so it is measured from
     * when that result starts being awaited, not from submission.
     */
    private void compute() {
        List<DrugStruct> computed = new ArrayList<>(listDrug.size());
        ExecutorService pool = Executors.newFixedThreadPool(poolSize);
        try {
            List<Future<DrugStruct>> workers = new ArrayList<>(listDrug.size());
            for (DrugStruct ds : listDrug) {
                workers.add(pool.submit(ds));
            }

            for (Future<DrugStruct> worker : workers) {
                try {
                    computed.add(worker.get(timeout, TimeUnit.SECONDS));
                } catch (TimeoutException ex) {
                    worker.cancel(true);
                    LOGGER.log(Level.WARNING,
                            "One worker killed, it took too much time to compute. Timeout = " + timeout + "s", ex);
                } catch (ExecutionException ex) {
                    LOGGER.log(Level.WARNING, "One worker failed while computing properties", ex);
                } catch (InterruptedException ex) {
                    // Honour the interrupt: stop waiting and let shutdownNow() cancel the rest.
                    Thread.currentThread().interrupt();
                    LOGGER.log(Level.WARNING, "Interrupted while waiting for workers; remaining molecules dropped", ex);
                    break;
                }
            }
        } finally {
            // Every result has been collected or abandoned; interrupt anything still running.
            pool.shutdownNow();
        }

        // Keep only the successfully computed molecules (also releases the originals).
        listDrug = computed;
    }

    /**
     * Predict the LogP of every flagged molecule and store it, rounded to two decimals, under
     * the {@code "predicted"} property.
     *
     * @throws Exception if the model fails to classify an instance
     */
    private void predict() throws Exception {
        Instances dataset = buildDataset();
        for (DrugStruct drugStruct : listDrug) {
            if (!isFlagged(drugStruct)) {
                continue;
            }
            Instance instance = toInstance(dataset, drugStruct);
            double predicted = model.classifyInstance(instance);
            predicted = Math.round(predicted * 100) / 100.0d;
            drugStruct.drug.setProperty("predicted", predicted);
        }
    }

    /**
     * Compute all properties, train a new RandomForest on the flagged molecules (target is
     * {@link DrugStruct#getLogP()}), and save it to {@code path + name}. Any existing file with
     * that name is overwritten.
     *
     * @param name file name of the model to save, relative to {@link #getPath()}
     * @throws Exception if training or serialization fails
     */
    public void train(String name) throws Exception {
        compute();

        Instances dataset = buildDataset();
        for (DrugStruct drugStruct : listDrug) {
            if (isFlagged(drugStruct)) {
                Instance instance = toInstance(dataset, drugStruct);
                instance.setClassValue(drugStruct.getLogP());
                dataset.add(instance);
            }
        }

        // Build into a local so a failed training run does not clobber the current model.
        RandomForest forest = new RandomForest();
        forest.setNumFeatures(200);
        forest.setNumTrees(400);
        forest.setMaxDepth(0);
        forest.buildClassifier(dataset);
        model = forest;

        weka.core.SerializationHelper.write(path + name, model);
    }

    /**
     * Convert the properties of a molecule into an instance bound to {@code dataset}.
     * Properties with no matching attribute, or with a null value, are logged and skipped.
     * Attributes without a matching property stay missing.
     */
    private Instance toInstance(Instances dataset, DrugStruct drugStruct) {
        Map<Object, Object> properties = drugStruct.drug.getProperties();
        // Named descriptors + fingerprint bits + class attribute.
        Instance instance = new DenseInstance(dataset.numAttributes());
        instance.setDataset(dataset);
        for (Map.Entry<Object, Object> entry : properties.entrySet()) {
            Object key = entry.getKey();
            if (!isFeatureKey(key)) {
                continue;
            }
            Attribute attribute = dataset.attribute(String.valueOf(key));
            Object value = entry.getValue();
            if (attribute == null || value == null) {
                LOGGER.log(Level.WARNING, "Property not used: {0}", String.valueOf(key));
                continue;
            }
            instance.setValue(attribute, Double.parseDouble(value.toString()));
        }
        return instance;
    }

    /** True unless {@code key} is one of the bookkeeping properties excluded from features. */
    private static boolean isFeatureKey(Object key) {
        return !("hash".equals(key) || "flag".equals(key) || "smiles".equals(key));
    }

    /** Null-safe check of the molecule's {@code "flag"} property; absent counts as false. */
    private static boolean isFlagged(DrugStruct drugStruct) {
        return Boolean.TRUE.equals(drugStruct.drug.getProperty("flag"));
    }

    /**
     * Build the empty dataset header: the named descriptors, then {@code X0..X1023}, then the
     * numeric class attribute {@code logp}.
     *
     * @return an empty Instances with its class index set
     */
    private Instances buildDataset() {
        ArrayList<Attribute> attInfo = new ArrayList<>(DESCRIPTOR_NAMES.length + FINGERPRINT_BITS + 1);
        for (String descriptor : DESCRIPTOR_NAMES) {
            attInfo.add(new Attribute(descriptor));
        }
        for (int bit = 0; bit < FINGERPRINT_BITS; bit++) {
            attInfo.add(new Attribute("X" + bit));
        }
        attInfo.add(new Attribute(CLASS_ATTRIBUTE));

        Instances instances = new Instances(CLASS_ATTRIBUTE, attInfo, 0);
        instances.setClassIndex(instances.attribute(CLASS_ATTRIBUTE).index());
        return instances;
    }

    /**
     * Write all flagged compounds into an XML file, replacing any existing file.
     *
     * @param name file name, relative to {@link #getPath()}
     * @throws Exception if the file cannot be written
     */
    public void printXML(String name) throws Exception {
        File file = new File(path + name);
        // Non-append FileWriter truncates an existing file.
        try (PrintWriter writer = new PrintWriter(new BufferedWriter(new FileWriter(file)))) {
            writer.write("<Drugs>");
            for (DrugStruct drug : listDrug) {
                if (isFlagged(drug)) {
                    drug.printXML(writer);
                }
            }
            writer.write("</Drugs>");
            // PrintWriter swallows IOExceptions; surface them instead of leaving a truncated file.
            if (writer.checkError()) {
                throw new IOException("Error while writing " + file);
            }
        }
    }

    /**
     * Print the {@code "logp"} property of every compound to standard output.
     *
     * <p>NOTE(review): {@link #predict()} stores predictions under {@code "predicted"}, not
     * {@code "logp"}. Confirm which key DrugStruct populates before relying on this output.
     */
    public void printPredicted() {
        for (DrugStruct drug : listDrug) {
            System.out.println(drug.drug.getProperty("logp"));
        }
    }

    /**
     * Get the {@code "logp"} property of every compound.
     *
     * <p>NOTE(review): {@link #predict()} stores predictions under {@code "predicted"}, not
     * {@code "logp"}. A compound without a {@code "logp"} property makes the unboxing below
     * throw. Confirm the intended key.
     *
     * @return list of values, one per compound
     */
    public List<Double> getPredicted() {
        List<Double> listPredicted = new ArrayList<>(listDrug.size());
        for (DrugStruct drug : listDrug) {
            listPredicted.add((double) drug.drug.getProperty("logp"));
        }
        return listPredicted;
    }

    /**
     * Write all flagged compounds into a CSV file, replacing any existing file.
     *
     * @param name file name, relative to {@link #getPath()}
     * @throws Exception if the file cannot be written
     */
    public void printCSV(String name) throws Exception {
        File file = new File(path + name);
        // Non-append FileWriter truncates an existing file.
        try (PrintWriter writer = new PrintWriter(new BufferedWriter(new FileWriter(file)))) {
            writer.write("id,name,smiles,predicted");
            writer.println();
            for (DrugStruct drug : listDrug) {
                if (isFlagged(drug)) {
                    drug.printCSV(writer);
                }
            }
            // PrintWriter swallows IOExceptions; surface them instead of leaving a truncated file.
            if (writer.checkError()) {
                throw new IOException("Error while writing " + file);
            }
        }
    }

    /**
     * Set the number of worker threads used by the property computation.
     *
     * @param poolSize number of threads, at least 1
     * @throws IllegalArgumentException if {@code poolSize < 1}
     */
    public void setPoolSize(int poolSize) {
        // Fail fast: Executors.newFixedThreadPool would otherwise reject it later.
        if (poolSize < 1) {
            throw new IllegalArgumentException("poolSize must be at least 1, got " + poolSize);
        }
        this.poolSize = poolSize;
    }

    /**
     * Get the number of worker threads.
     *
     * @return pool size
     */
    public int getPoolSize() {
        return poolSize;
    }

    /**
     * Set the per-molecule timeout in seconds. A worker whose result is not available within
     * this time is cancelled and its molecule dropped.
     *
     * @param timeout timeout in seconds
     */
    public void setTimeout(int timeout) {
        this.timeout = timeout;
    }

    /**
     * Get the per-molecule timeout.
     *
     * @return timeout in seconds
     */
    public int getTimeout() {
        return timeout;
    }

    /**
     * Set the directory prefix used for XML/CSV output and model files. It is concatenated
     * directly with file names, so it should end with a separator.
     *
     * @param path directory prefix
     */
    public void setPath(String path) {
        this.path = path;
    }

    /**
     * Get the directory prefix.
     *
     * @return directory prefix
     */
    public String getPath() {
        return path;
    }
}