predictors.HelixIndexer.java Source code

Java tutorial

Introduction

Here is the source code for predictors.HelixIndexer.java

Source

//    <TMSEG: Prediction of Transmembrane Helices in Proteins.>
//    Copyright (C) 2014  Michael Bernhofer
//
//    This program is free software: you can redistribute it and/or modify
//    it under the terms of the GNU General Public License as published by
//    the Free Software Foundation, either version 3 of the License, or
//    (at your option) any later version.
//
//    This program is distributed in the hope that it will be useful,
//    but WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//    GNU General Public License for more details.
//
//    You should have received a copy of the GNU General Public License
//    along with this program.  If not, see <http://www.gnu.org/licenses/>.

package predictors;

import io.ModelHandler;

import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import util.ErrorUtils;
import util.Globals;
import util.Mappings;
import weka.classifiers.Classifier;
import weka.classifiers.trees.RandomForest;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SparseInstance;
import weka.core.converters.ArffSaver;
import data.Protein;
import data.Pssm;

/**
 * Predicts transmembrane residues within a protein.
 * <p>
 * Each residue is described by a window of PSSM scores centred on it, a set of
 * summary features over an inner window, discretized distances to both termini,
 * the discretized protein length and the protein-wide composition. A Weka random
 * forest assigns every residue one of three classes: {@link #indexNotTmh},
 * {@link #indexTmh} or {@link #indexSignal}.
 * <p>
 * The composition of the protein currently being processed is kept in instance
 * fields, so one instance must not be used by several threads at the same time.
 */
public class HelixIndexer {

    public static final int indexNotTmh = 0;
    public static final int indexTmh = 1;
    public static final int indexSignal = 2;

    /** Number of amino acid score columns read from the PSSM for each position. */
    private static final int AA_COUNT = 20;

    /** Highest value of the discretized distance and length attributes (bins 0..4). */
    private static final int MAX_BIN = 4;

    /** Bin width used for the distances to the N- and C-terminus. */
    private static final int TERMINUS_BIN_WIDTH = 10;

    /** Bin width used for the protein length. */
    private static final int LENGTH_BIN_WIDTH = 60;

    /** Signal peptide scores are only kept for residues with an index below this value. */
    private static final int SIGNAL_REGION_LENGTH = 40;

    /** Class probabilities are multiplied by this factor and truncated to integers. */
    private static final int SCORE_SCALE = 1000;

    /** Value of the extra per-position column when the position lies inside the sequence. */
    private static final int INSIDE_SEQUENCE_MARKER = -10;

    /** Value of the extra per-position column when the position lies outside the sequence. */
    private static final int OUTSIDE_SEQUENCE_MARKER = 10;

    private ArrayList<Attribute> attributes = null;
    private Instances dataset = null;
    private Classifier classifier = null;
    private boolean isTrained = false;

    private double[] globalConsAa = null;
    private double[] globalNonConsAa = null;

    /**
     * Creates an indexer with a complete attribute list, an empty dataset and no classifier.
     */
    public HelixIndexer() {
        this.buildAttributes();
        this.initialize();
    }

    /**
     * Initializes the attributes list for the Weka arff-format.
     * <p>
     * The order in which attributes are added here must match the order in which
     * {@link #buildInstance(Pssm, int)} writes its values.
     */
    private void buildAttributes() {
        this.attributes = new ArrayList<Attribute>();

        int windowLength = (2 * Globals.INDEXER_WINDOW_SIZE) + 1;

        // Per window position: 20 amino acid columns plus one marker column (index 20).
        for (int i = 1; i <= windowLength; ++i) {
            for (int j = 0; j <= AA_COUNT; ++j) {
                this.attributes.add(new Attribute(Mappings.intToAa(j) + "_" + i));
            }
        }

        // Inner-window summary features, each for conserved and non-conserved columns.
        String[] properties = { "hydrophobicity", "hydrophobic", "pos_charged", "neg_charged", "polar" };

        for (String property : properties) {
            this.attributes.add(new Attribute("conserved_" + property));
            this.attributes.add(new Attribute("non-conserved_" + property));
        }

        // Nominal bins "0".."4" shared by the distance and length attributes.
        ArrayList<String> bins = new ArrayList<String>();

        for (int bin = 0; bin <= MAX_BIN; ++bin) {
            bins.add(String.valueOf(bin));
        }

        this.attributes.add(new Attribute("n-term_distance", bins));
        this.attributes.add(new Attribute("c-term_distance", bins));
        this.attributes.add(new Attribute("global_length", bins));

        for (int j = 0; j < AA_COUNT; ++j) {
            this.attributes.add(new Attribute("global_conserved_" + Mappings.intToAa(j)));
            this.attributes.add(new Attribute("global_non-conserved_" + Mappings.intToAa(j)));
        }

        ArrayList<String> classes = new ArrayList<String>();

        classes.add(String.valueOf(HelixIndexer.indexNotTmh));
        classes.add(String.valueOf(HelixIndexer.indexTmh));
        classes.add(String.valueOf(HelixIndexer.indexSignal));

        this.attributes.add(new Attribute("class", classes));
    }

    /**
     * Initializes the classifier and dataset.
     * <p>
     * Discards any trained or loaded classifier and replaces the dataset with an
     * empty one whose class attribute is the last attribute.
     */
    public void initialize() {
        this.isTrained = false;
        this.classifier = null;
        this.dataset = new Instances("HelixIndexer Model", this.attributes, 0);

        this.dataset.setClassIndex(this.attributes.size() - 1);
    }

    /**
     * Inputs a given list of proteins for the training data.
     *
     * @param proteins proteins to add; a {@code null} list is ignored
     */
    public void input(List<Protein> proteins) {
        if (proteins == null) {
            return;
        }

        for (Protein protein : proteins) {
            this.input(protein);
        }
    }

    /**
     * Inputs a given protein for the training data.
     * <p>
     * Proteins without structure annotation or PSSM are skipped silently. A protein
     * whose PSSM and annotation differ in length is reported and skipped. Residues
     * whose annotation maps to {@code Mappings.indexUnknown} are not added.
     *
     * @param protein protein to add; {@code null} is ignored
     */
    public void input(Protein protein) {
        if (protein == null || protein.getStructure() == null || protein.getPssm() == null) {
            return;
        }

        Pssm pssm = protein.getPssm();
        int length = pssm.getLength();
        char[] structure = protein.getStructure();

        if (length != structure.length) {
            ErrorUtils.printError(HelixIndexer.class,
                    "PSSM and structure annotation length do not match for " + protein.getName(), null);

            return;
        }

        this.globalComposition(pssm);

        for (int i = 0; i < length; ++i) {
            if (Mappings.ssToInt(structure[i]) != Mappings.indexUnknown) {
                this.addWindowToDatabase(pssm, i, structure);
            }
        }
    }

    /**
     * Predicts transmembrane residues for a given list of proteins.
     *
     * @param proteins proteins to predict; a {@code null} list is ignored
     */
    public void predict(List<Protein> proteins) {
        if (proteins == null) {
            return;
        }

        for (Protein protein : proteins) {
            this.predict(protein);
        }
    }

    /**
     * Predicts transmembrane residues for a given protein.
     * <p>
     * On success the per-residue scores (class probability times 1000, truncated)
     * are stored in the protein. The signal peptide score is forced to 0 for all
     * residues from index 40 onwards. If the prediction fails for any residue the
     * error is reported and no scores are stored.
     *
     * @param protein protein to predict; {@code null} or a protein without PSSM is ignored
     */
    public void predict(Protein protein) {
        if (protein == null || protein.getPssm() == null) {
            return;
        }

        // Fail once with a clear message instead of running into a NullPointerException.
        if (this.classifier == null) {
            ErrorUtils.printError(HelixIndexer.class,
                    "Prediction failed for " + protein.getHeader() + ": no trained or loaded model available",
                    null);

            return;
        }

        Pssm pssm = protein.getPssm();
        int length = pssm.getLength();
        int[] scoresSol = new int[length];
        int[] scoresTmh = new int[length];
        int[] scoresSig = new int[length];
        Attribute classAttribute = this.classAttribute();

        this.globalComposition(pssm);

        //slide window along the sequence
        for (int i = 0; i < length; ++i) {
            try {
                Instance window = this.buildInstance(pssm, i);

                // The class is what we want to predict, so mark it as unknown explicitly.
                window.setMissing(classAttribute);
                window.setDataset(this.dataset);

                double[] probabilities = this.classifier.distributionForInstance(window);

                scoresSol[i] = (int) (SCORE_SCALE * probabilities[HelixIndexer.indexNotTmh]);
                scoresTmh[i] = (int) (SCORE_SCALE * probabilities[HelixIndexer.indexTmh]);

                if (i < SIGNAL_REGION_LENGTH) {
                    scoresSig[i] = (int) (SCORE_SCALE * probabilities[HelixIndexer.indexSignal]);
                } else {
                    scoresSig[i] = 0;
                }
            } catch (Exception e) {
                ErrorUtils.printError(HelixIndexer.class, "Prediction failed for " + protein.getHeader(), e);

                return;
            }
        }

        //save scores into the protein
        protein.setSolRaw(scoresSol);
        protein.setTmhRaw(scoresTmh);
        protein.setSigRaw(scoresSig);
    }

    /**
     * Analyzes a given window and saves it in the database.
     *
     * @param pssm         PSSM of the protein the window belongs to
     * @param windowCenter index of the residue at the centre of the window
     * @param structure    structure annotation of the protein, one character per residue
     */
    private void addWindowToDatabase(Pssm pssm, int windowCenter, char[] structure) {
        int annotation = Mappings.ssToInt(structure[windowCenter]);
        int label;

        // Everything that is neither helix nor signal peptide counts as "not TMH".
        if (annotation == Mappings.indexTmh) {
            label = HelixIndexer.indexTmh;
        } else if (annotation == Mappings.indexSignal) {
            label = HelixIndexer.indexSignal;
        } else {
            label = HelixIndexer.indexNotTmh;
        }

        Instance window = this.buildInstance(pssm, windowCenter);

        window.setValue(this.classAttribute(), label);
        window.setDataset(this.dataset);

        this.dataset.add(window);
    }

    /**
     * Calculates the amino acid composition of the whole protein.
     * <p>
     * For each of the 20 amino acid columns this counts the positions with a
     * positive score (conserved) and with a negative score (non-conserved) and
     * divides by the respective total over all columns. The results are stored in
     * {@code globalConsAa} and {@code globalNonConsAa}.
     *
     * @param pssm PSSM of the protein
     */
    private void globalComposition(Pssm pssm) {
        int conserved = 0;
        int nonConserved = 0;
        double[] globalConserved = new double[AA_COUNT];
        double[] globalNonConserved = new double[AA_COUNT];
        int length = pssm.getLength();

        for (int i = 0; i < length; ++i) {
            for (int j = 0; j < AA_COUNT; ++j) {
                int score = pssm.getScore(i, j);

                if (score > 0) {
                    ++globalConserved[j];
                    ++conserved;
                } else if (score < 0) {
                    ++globalNonConserved[j];
                    ++nonConserved;
                }
            }
        }

        // Avoid 0/0 = NaN for PSSMs without any positive or negative score; the
        // fractions are then 0, matching the guard used for the window features.
        conserved = Math.max(conserved, 1);
        nonConserved = Math.max(nonConserved, 1);

        for (int j = 0; j < AA_COUNT; ++j) {
            globalConserved[j] = globalConserved[j] / conserved;
            globalNonConserved[j] = globalNonConserved[j] / nonConserved;
        }

        this.globalConsAa = globalConserved;
        this.globalNonConsAa = globalNonConserved;
    }

    /**
     * Converts a given window into a Weka Instance.
     * <p>
     * Values are written in the order defined by {@link #buildAttributes()}; the
     * class attribute is not set. {@link #globalComposition(Pssm)} must have been
     * called for the same PSSM beforehand.
     *
     * @param pssm         PSSM of the protein
     * @param windowCenter index of the residue at the centre of the window
     * @return the instance describing the window, not yet attached to a dataset
     */
    private Instance buildInstance(Pssm pssm, int windowCenter) {
        Instance window = new SparseInstance(this.attributes.size());
        int windowStart = windowCenter - Globals.INDEXER_WINDOW_SIZE;
        int windowStop = windowCenter + Globals.INDEXER_WINDOW_SIZE;
        int globalLength = pssm.getLength();
        int attIndex = 0;
        int conserved = 0;
        int nonConserved = 0;
        double consAvgHydro = 0;
        double nonConsAvgHydro = 0;
        double consHydro = 0;
        double nonConsHydro = 0;
        double consPCharged = 0;
        double nonConsPCharged = 0;
        double consNCharged = 0;
        double nonConsNCharged = 0;
        double consPolar = 0;
        double nonConsPolar = 0;

        //amino acid at position i in window
        for (int i = windowStart; i <= windowStop; ++i) {
            if (i >= 0 && i < globalLength) {
                boolean inInnerWindow = Math.abs(i - windowCenter) <= Globals.INDEXER_INNER_WINDOW_SIZE;

                for (int j = 0; j < AA_COUNT; ++j) {
                    int score = pssm.getScore(i, j);

                    // Summary features only consider the inner window around the centre.
                    if (inInnerWindow) {
                        char aa = Mappings.intToAa(j);

                        if (score > 0) {
                            consAvgHydro += Mappings.hydrophobicity(aa);

                            if (Mappings.hydrophobicity(aa) > 0) {
                                ++consHydro;
                            }
                            if (Mappings.charge(aa) > 0) {
                                ++consPCharged;
                            }
                            if (Mappings.charge(aa) < 0) {
                                ++consNCharged;
                            }
                            if (Mappings.polarity(aa) > 0) {
                                ++consPolar;
                            }

                            ++conserved;
                        } else if (score < 0) {
                            nonConsAvgHydro += Mappings.hydrophobicity(aa);

                            if (Mappings.hydrophobicity(aa) > 0) {
                                ++nonConsHydro;
                            }
                            if (Mappings.charge(aa) > 0) {
                                ++nonConsPCharged;
                            }
                            if (Mappings.charge(aa) < 0) {
                                ++nonConsNCharged;
                            }
                            if (Mappings.polarity(aa) > 0) {
                                ++nonConsPolar;
                            }

                            ++nonConserved;
                        }
                    }

                    window.setValue(this.attributes.get(attIndex++), score);
                }

                window.setValue(this.attributes.get(attIndex++), INSIDE_SEQUENCE_MARKER);
            } else {
                // Position lies before the first or after the last residue.
                for (int j = 0; j < AA_COUNT; ++j) {
                    window.setValue(this.attributes.get(attIndex++), 0);
                }

                window.setValue(this.attributes.get(attIndex++), OUTSIDE_SEQUENCE_MARKER);
            }
        }

        // Guard against division by zero when no column was (non-)conserved.
        conserved = Math.max(conserved, 1);
        nonConserved = Math.max(nonConserved, 1);

        window.setValue(this.attributes.get(attIndex++), consAvgHydro / conserved);
        window.setValue(this.attributes.get(attIndex++), nonConsAvgHydro / nonConserved);

        window.setValue(this.attributes.get(attIndex++), consHydro / conserved);
        window.setValue(this.attributes.get(attIndex++), nonConsHydro / nonConserved);

        window.setValue(this.attributes.get(attIndex++), consPCharged / conserved);
        window.setValue(this.attributes.get(attIndex++), nonConsPCharged / nonConserved);

        window.setValue(this.attributes.get(attIndex++), consNCharged / conserved);
        window.setValue(this.attributes.get(attIndex++), nonConsNCharged / nonConserved);

        window.setValue(this.attributes.get(attIndex++), consPolar / conserved);
        window.setValue(this.attributes.get(attIndex++), nonConsPolar / nonConserved);

        int nTermBin = discretize(windowCenter + 1, TERMINUS_BIN_WIDTH);
        int cTermBin = discretize(globalLength - windowCenter, TERMINUS_BIN_WIDTH);
        int lengthBin = discretize(globalLength, LENGTH_BIN_WIDTH);

        window.setValue(this.attributes.get(attIndex++), nTermBin);
        window.setValue(this.attributes.get(attIndex++), cTermBin);
        window.setValue(this.attributes.get(attIndex++), lengthBin);

        for (int j = 0; j < AA_COUNT; ++j) {
            window.setValue(this.attributes.get(attIndex++), this.globalConsAa[j]);
            window.setValue(this.attributes.get(attIndex++), this.globalNonConsAa[j]);
        }

        return window;
    }

    /**
     * Maps a value to a bin between 0 and 4.
     * <p>
     * Returns the largest {@code bin} in 1..4 with {@code value > bin * binWidth},
     * or 0 if there is none.
     *
     * @param value    value to discretize
     * @param binWidth width of a single bin
     * @return the bin index in the range 0..4
     */
    private static int discretize(int value, int binWidth) {
        for (int bin = MAX_BIN; bin > 0; --bin) {
            if (value > bin * binWidth) {
                return bin;
            }
        }

        return 0;
    }

    /**
     * Returns the class attribute, which is always the last entry of the attribute list.
     *
     * @return the class attribute
     */
    private Attribute classAttribute() {
        return this.attributes.get(this.attributes.size() - 1);
    }

    /**
     * Trains the Weka Classifer.
     * <p>
     * Shuffles the collected dataset and builds a random forest on it. Both the
     * shuffle and the forest are seeded with the dataset size. Failures are
     * reported and leave the current classifier untouched.
     */
    public void trainClassifier() {
        try {
            RandomForest forest = new weka.classifiers.trees.RandomForest();
            Instances data = this.dataset;

            if (data.classIndex() == -1) {
                data.setClassIndex(data.numAttributes() - 1);
            }

            data.randomize(new Random(data.size()));

            String[] optClassifier = weka.core.Utils.splitOptions("-I 100 -K 9 -S 1 -num-slots 3");

            forest.setOptions(optClassifier);
            // Set after the options, so this replaces the seed given by "-S 1".
            forest.setSeed(data.size());

            forest.buildClassifier(data);

            this.classifier = forest;
            this.isTrained = true;
        } catch (Exception e) {
            ErrorUtils.printError(HelixIndexer.class, "Training failed", e);
        }
    }

    /**
     * Saves the Weka arff data and model into files.
     * Extensions: Arff -> .arff; Model -> .[model|model.gz]
     * <p>
     * Trains the classifier first if that has not happened yet. If no classifier is
     * available afterwards, only the arff file is written and an error is reported.
     *
     * @param filename path of the output files without extension
     */
    public void dumpModelToFile(String filename) {
        //if classifier is still untrained, train it!
        if (!this.isTrained) {
            this.trainClassifier();
        }

        try {
            ArffSaver saver = new ArffSaver();

            saver.setInstances(this.dataset);
            saver.setFile(new File(filename + ".arff"));
            saver.writeBatch();

            // Never serialize a missing classifier; such a file would load as null.
            if (this.classifier == null) {
                ErrorUtils.printError(HelixIndexer.class,
                        "Failed to write model file(s) " + filename
                                + ".[model|model.gz]: no trained classifier available",
                        null);

                return;
            }

            weka.core.SerializationHelper.write(filename + ".model", this.classifier);

            ModelHandler.saveGZip(this.classifier, filename + ".model.gz");
        } catch (Exception e) {
            ErrorUtils.printError(HelixIndexer.class,
                    "Failed to write model file(s) " + filename + ".[arff|model|model.gz]", e);
        }
    }

    /**
     * Loads a previously saved model.
     * <p>
     * The uncompressed {@code .model} file is preferred over {@code .model.gz}. A
     * successfully loaded classifier counts as trained, so a later call to
     * {@link #dumpModelToFile(String)} does not retrain and replace it. On failure
     * the error is reported and the current classifier is kept.
     *
     * @param filename path of the model file without extension
     */
    public void loadModelFromFile(String filename) {
        try {
            File model = new File(filename + ".model");
            File modelGzip = new File(filename + ".model.gz");
            Classifier loaded;

            if (model.exists()) {
                loaded = (Classifier) weka.core.SerializationHelper.read(filename + ".model");
            } else if (modelGzip.exists()) {
                loaded = ModelHandler.loadGZip(filename + ".model.gz");
            } else {
                throw new FileNotFoundException(filename + ".[model|model.gz]");
            }

            this.classifier = loaded;
            this.isTrained = (loaded != null);
        } catch (Exception e) {
            ErrorUtils.printError(HelixIndexer.class,
                    "Failed to load model file(s) " + filename + ".[model|model.gz]", e);
        }
    }
}