List of usage examples for weka.core Instances getRandomNumberGenerator
public Random getRandomNumberGenerator(long seed)
From source file: ai.BalancedRandomForest.java
License: GNU General Public License
/** * Build Balanced Random Forest// w w w. ja v a 2 s.co m */ public void buildClassifier(final Instances data) throws Exception { // If number of features is 0 then set it to log2 of M (number of attributes) if (numFeatures < 1) numFeatures = (int) Utils.log2(data.numAttributes()) + 1; // Check maximum number of random features if (numFeatures >= data.numAttributes()) numFeatures = data.numAttributes() - 1; // Initialize array of trees tree = new BalancedRandomTree[numTrees]; // total number of instances final int numInstances = data.numInstances(); // total number of classes final int numClasses = data.numClasses(); final ArrayList<Integer>[] indexSample = new ArrayList[numClasses]; for (int i = 0; i < numClasses; i++) indexSample[i] = new ArrayList<Integer>(); //System.out.println("numClasses = " + numClasses); // fill indexSample with the indices of each class for (int i = 0; i < numInstances; i++) { //System.out.println("data.get("+i+").classValue() = " + data.get(i).classValue()); indexSample[(int) data.get(i).classValue()].add(i); } final Random random = new Random(seed); // Executor service to run concurrent trees final ExecutorService exe = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); List<Future<BalancedRandomTree>> futures = new ArrayList<Future<BalancedRandomTree>>(numTrees); final boolean[][] inBag = new boolean[numTrees][numInstances]; try { for (int i = 0; i < numTrees; i++) { final ArrayList<Integer> bagIndices = new ArrayList<Integer>(); // Randomly select the indices in a balanced way for (int j = 0; j < numInstances; j++) { // Select first the class final int randomClass = random.nextInt(numClasses); // Select then a random sample of that class final int randomSample = random.nextInt(indexSample[randomClass].size()); bagIndices.add(indexSample[randomClass].get(randomSample)); inBag[i][indexSample[randomClass].get(randomSample)] = true; } // Create random tree final Splitter splitter = new Splitter( new 
GiniFunction(numFeatures, data.getRandomNumberGenerator(random.nextInt()))); futures.add(exe.submit(new Callable<BalancedRandomTree>() { public BalancedRandomTree call() { return new BalancedRandomTree(data, bagIndices, splitter); } })); } // Grab all trained trees before proceeding for (int treeIdx = 0; treeIdx < numTrees; treeIdx++) tree[treeIdx] = futures.get(treeIdx).get(); // Calculate out of bag error final boolean numeric = data.classAttribute().isNumeric(); List<Future<Double>> votes = new ArrayList<Future<Double>>(data.numInstances()); for (int i = 0; i < data.numInstances(); i++) { VotesCollector aCollector = new VotesCollector(tree, i, data, inBag); votes.add(exe.submit(aCollector)); } double outOfBagCount = 0.0; double errorSum = 0.0; for (int i = 0; i < data.numInstances(); i++) { double vote = votes.get(i).get(); // error for instance outOfBagCount += data.instance(i).weight(); if (numeric) { errorSum += StrictMath.abs(vote - data.instance(i).classValue()) * data.instance(i).weight(); } else { if (vote != data.instance(i).classValue()) errorSum += data.instance(i).weight(); } } outOfBagError = errorSum / outOfBagCount; } catch (Exception ex) { ex.printStackTrace(); } finally { exe.shutdownNow(); } }
From source file: br.com.ufu.lsi.rebfnetwork.RBFModel.java
License: Open Source License
/**
 * Pre-processes the data, performs k-means clustering to place the RBF
 * centers, and sets the initial parameter vector (weights, centers, scales).
 *
 * @param data the training instances
 * @return the pre-processed (filtered, normalized) copy of the data
 * @throws Exception if all class values are identical, or a filter fails
 */
protected Instances initializeClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // work on a copy, with missing-class instances removed
    data = new Instances(data);
    data.deleteWithMissingClass();

    // Make sure data is shuffled
    Random random = new Random(m_Seed);
    if (data.numInstances() > 2) {
        random = data.getRandomNumberGenerator(m_Seed);
    }
    data.randomize(random);

    double y0 = data.instance(0).classValue(); // This stuff is not relevant in classification case

    // find the first instance with a class value different from y0
    int index = 1;
    while (index < data.numInstances() && data.instance(index).classValue() == y0) {
        index++;
    }
    if (index == data.numInstances()) {
        // degenerate case, all class values are equal
        // we don't want to deal with this, too much hassle
        throw new Exception("All class values are the same. At least two class values should be different");
    }
    double y1 = data.instance(index).classValue();

    // Replace missing values
    m_ReplaceMissingValues = new ReplaceMissingValues();
    m_ReplaceMissingValues.setInputFormat(data);
    data = Filter.useFilter(data, m_ReplaceMissingValues);

    // Remove useless attributes
    m_AttFilter = new RemoveUseless();
    m_AttFilter.setInputFormat(data);
    data = Filter.useFilter(data, m_AttFilter);

    // only class? -> build ZeroR model
    if (data.numAttributes() == 1) {
        System.err.println(
                "Cannot build model (only class attribute present in data after removing useless attributes!), "
                        + "using ZeroR model instead!");
        m_ZeroR = new weka.classifiers.rules.ZeroR();
        m_ZeroR.buildClassifier(data);
        return data;
    } else {
        m_ZeroR = null;
    }

    // Transform attributes: binarize nominals, then normalize (class included)
    m_NominalToBinary = new NominalToBinary();
    m_NominalToBinary.setInputFormat(data);
    data = Filter.useFilter(data, m_NominalToBinary);

    m_Filter = new Normalize();
    ((Normalize) m_Filter).setIgnoreClass(true);
    m_Filter.setInputFormat(data);
    data = Filter.useFilter(data, m_Filter);

    // linear map from normalized class values back to original scale
    double z0 = data.instance(0).classValue(); // This stuff is not relevant in classification case
    double z1 = data.instance(index).classValue();
    m_x1 = (y0 - y1) / (z0 - z1); // no division by zero, since y0 != y1 guaranteed => z0 != z1 ???
    m_x0 = (y0 - m_x1 * z0); // = y1 - m_x1 * z1

    m_classIndex = data.classIndex();
    m_numClasses = data.numClasses();
    m_numAttributes = data.numAttributes();

    // Run k-means on the data with the class attribute removed
    SimpleKMeans skm = new SimpleKMeans();
    skm.setMaxIterations(10000);
    skm.setNumClusters(m_numUnits);
    Remove rm = new Remove();
    data.setClassIndex(-1); // temporarily unset so Remove may drop the class column
    rm.setAttributeIndices((m_classIndex + 1) + "");
    rm.setInputFormat(data);
    Instances dataRemoved = Filter.useFilter(data, rm);
    data.setClassIndex(m_classIndex); // restore
    skm.buildClusterer(dataRemoved);
    Instances centers = skm.getClusterCentroids();

    // k-means may return fewer distinct centroids than requested
    if (centers.numInstances() < m_numUnits) {
        m_numUnits = centers.numInstances();
    }

    // Set up parameter-vector layout (offsets into m_RBFParameters)
    OFFSET_WEIGHTS = 0;
    if (m_useAttributeWeights) {
        OFFSET_ATTRIBUTE_WEIGHTS = (m_numUnits + 1) * m_numClasses;
        OFFSET_CENTERS = OFFSET_ATTRIBUTE_WEIGHTS + m_numAttributes;
    } else {
        OFFSET_ATTRIBUTE_WEIGHTS = -1;
        OFFSET_CENTERS = (m_numUnits + 1) * m_numClasses;
    }
    OFFSET_SCALES = OFFSET_CENTERS + m_numUnits * m_numAttributes;

    // number of scale parameters depends on the optimization option
    switch (m_scaleOptimizationOption) {
    case USE_GLOBAL_SCALE:
        m_RBFParameters = new double[OFFSET_SCALES + 1];
        break;
    case USE_SCALE_PER_UNIT_AND_ATTRIBUTE:
        m_RBFParameters = new double[OFFSET_SCALES + m_numUnits * m_numAttributes];
        break;
    default:
        m_RBFParameters = new double[OFFSET_SCALES + m_numUnits];
        break;
    }

    // Set initial radius based on distance to nearest other basis function
    // (maxMinDist = largest nearest-neighbour squared distance between centers)
    double maxMinDist = -1;
    for (int i = 0; i < centers.numInstances(); i++) {
        double minDist = Double.MAX_VALUE;
        for (int j = i + 1; j < centers.numInstances(); j++) {
            double dist = 0;
            for (int k = 0; k < centers.numAttributes(); k++) {
                if (k != centers.classIndex()) {
                    double diff = centers.instance(i).value(k) - centers.instance(j).value(k);
                    dist += diff * diff;
                }
            }
            if (dist < minDist) {
                minDist = dist;
            }
        }
        if ((minDist != Double.MAX_VALUE) && (minDist > maxMinDist)) {
            maxMinDist = minDist;
        }
    }

    // Initialize parameters: scales from maxMinDist, centers from k-means
    if (m_scaleOptimizationOption == USE_GLOBAL_SCALE) {
        m_RBFParameters[OFFSET_SCALES] = Math.sqrt(maxMinDist);
    }
    for (int i = 0; i < m_numUnits; i++) {
        if (m_scaleOptimizationOption == USE_SCALE_PER_UNIT) {
            m_RBFParameters[OFFSET_SCALES + i] = Math.sqrt(maxMinDist);
        }
        // k walks the centroid attributes (class column removed), j walks the
        // data attributes (class column present) — they deliberately diverge
        int k = 0;
        for (int j = 0; j < m_numAttributes; j++) {
            if (k == centers.classIndex()) {
                k++;
            }
            if (j != data.classIndex()) {
                if (m_scaleOptimizationOption == USE_SCALE_PER_UNIT_AND_ATTRIBUTE) {
                    m_RBFParameters[OFFSET_SCALES + (i * m_numAttributes + j)] = Math.sqrt(maxMinDist);
                }
                m_RBFParameters[OFFSET_CENTERS + (i * m_numAttributes) + j] = centers.instance(i).value(k);
                k++;
            }
        }
    }

    // attribute weights start at 1.0 (class column skipped)
    if (m_useAttributeWeights) {
        for (int j = 0; j < m_numAttributes; j++) {
            if (j != data.classIndex()) {
                m_RBFParameters[OFFSET_ATTRIBUTE_WEIGHTS + j] = 1.0;
            }
        }
    }

    initializeOutputLayer(random);

    return data;
}
From source file: meka.classifiers.multilabel.MULAN.java
License: Open Source License
@Override public void buildClassifier(Instances instances) throws Exception { testCapabilities(instances);//from w w w. j a va 2 s.c o m long before = System.currentTimeMillis(); if (getDebug()) System.err.print(" moving target attributes to the beginning ... "); Random r = instances.getRandomNumberGenerator(0); String name = "temp_" + MLUtils.getDatasetName(instances) + "_" + r.nextLong() + ".arff"; System.err.println("Using temporary file: " + name); int L = instances.classIndex(); // rename attributes, because MULAN doesn't deal well with hypens etc for (int i = L; i < instances.numAttributes(); i++) { instances.renameAttribute(i, "a_" + i); } BufferedWriter writer = new BufferedWriter(new FileWriter(name)); m_InstancesTemplate = F.meka2mulan(new Instances(instances), L); writer.write(m_InstancesTemplate.toString()); writer.flush(); writer.close(); MultiLabelInstances train = new MultiLabelInstances(name, L); try { new File(name).delete(); } catch (Exception e) { System.err.println( "[Error] Failed to delete temporary file: " + name + ". 
You may want to delete it manually."); } if (getDebug()) System.out.println(" done "); long after = System.currentTimeMillis(); System.err.println("[Note] Discount " + ((after - before) / 1000.0) + " seconds from this build time"); m_InstancesTemplate = new Instances(train.getDataSet(), 0); System.out.println("CLASSIFIER " + m_Classifier); //m_InstancesTemplate.delete(); if (m_MethodString.equals("BR")) m_MULAN = new BinaryRelevance(m_Classifier); else if (m_MethodString.equals("LP")) m_MULAN = new LabelPowerset(m_Classifier); else if (m_MethodString.equals("CLR")) m_MULAN = new CalibratedLabelRanking(m_Classifier); else if (m_MethodString.equals("RAkEL1")) { m_MULAN = new RAkEL(new LabelPowerset(m_Classifier), 10, L / 2); System.out.println("m=10,k=" + (L / 2)); } else if (m_MethodString.equals("RAkEL2")) { m_MULAN = new RAkEL(new LabelPowerset(m_Classifier), 2 * L, 3); System.out.println("m=" + (L * 2) + ",k=3"); } else if (m_MethodString.equals("MLkNN")) m_MULAN = new MLkNN(10, 1.0); else if (m_MethodString.equals("IBLR_ML")) m_MULAN = new IBLR_ML(10); else if (m_MethodString.equals("BPMLL")) { //BPMLL is run withthe number of hidden units equal to 20% of the input units. 
m_MULAN = new BPMLL(); ((BPMLL) m_MULAN).setLearningRate(0.01); ((BPMLL) m_MULAN).setHiddenLayers(new int[] { 30 }); ((BPMLL) m_MULAN).setTrainingEpochs(100); } else if (m_MethodString.startsWith("HOMER")) { //Class m = Class.forName("HierarchyBuilder.Method.Random"); //Class w = Class.forName("mulan.classifier.LabelPowerset"); //Constructor c = new h.getConstructor(new Class[]{MultiLabelLearner.class, Integer.TYPE, HierarchyBuilder.Method.class}); //Object obj = h.newInstance(); String ops[] = m_MethodString.split("\\."); // number of clusters int n = 3; try { n = Integer.parseInt(ops[2]); } catch (Exception e) { System.err.println("[Warning] Could not parse number of clusters, using default: " + n); } // learner // @TODO use reflection here MultiLabelLearner mll = new LabelPowerset(m_Classifier); if (ops[3].equalsIgnoreCase("BinaryRelevance")) { mll = new BinaryRelevance(m_Classifier); } else if (ops[3].equalsIgnoreCase("ClassifierChain")) { mll = new ClassifierChain(m_Classifier); } else if (ops[3].equalsIgnoreCase("LabelPowerset")) { // already set } else { System.err.println( "[Warning] Did not recognise classifier type String, using default: LabelPowerset"); } if (getDebug()) { System.out.println("HOMER(" + mll + "," + n + "," + ops[1] + ")"); } m_MULAN = new HOMER(mll, n, HierarchyBuilder.Method.valueOf(ops[1])); } else throw new Exception("Could not find MULAN Classifier by that name: " + m_MethodString); m_MULAN.setDebug(getDebug()); m_MULAN.build(train); }
From source file: org.scripps.branch.classifier.ManualTree.java
License: Open Source License
/** * Builds classifier./* w w w . j av a2s . com*/ * * @param data * the data to train with * @throws Exception * if something goes wrong or the data doesn't fit */ @Override public void buildClassifier(Instances data) throws Exception { // Make sure K value is in range if (m_KValue > data.numAttributes() - 1) m_KValue = data.numAttributes() - 1; if (m_KValue < 1) m_KValue = (int) Utils.log2(data.numAttributes()) + 1; // can classifier handle the data? getCapabilities().testWithFail(data); // remove instances with missing class data = new Instances(data); data.deleteWithMissingClass(); // only class? -> build ZeroR model if (data.numAttributes() == 1) { System.err.println( "Cannot build model (only class attribute present in data!), " + "using ZeroR model instead!"); m_ZeroR = new weka.classifiers.rules.ZeroR(); m_ZeroR.buildClassifier(data); return; } else { m_ZeroR = null; } // Figure out appropriate datasets Instances train = null; Instances backfit = null; Random rand = data.getRandomNumberGenerator(m_randomSeed); if (m_NumFolds <= 0) { train = data; } else { data.randomize(rand); data.stratify(m_NumFolds); train = data.trainCV(m_NumFolds, 1, rand); backfit = data.testCV(m_NumFolds, 1); } //Set Default Instances for selection. 
setRequiredInst(data); // Create the attribute indices window int[] attIndicesWindow = new int[data.numAttributes() - 1]; int j = 0; for (int i = 0; i < attIndicesWindow.length; i++) { if (j == data.classIndex()) j++; // do not include the class attIndicesWindow[i] = j++; } // Compute initial class counts double[] classProbs = new double[train.numClasses()]; for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); classProbs[(int) inst.classValue()] += inst.weight(); } Instances requiredInstances = getRequiredInst(); // Build tree if (jsontree != null) { buildTree(train, classProbs, new Instances(data, 0), m_Debug, 0, jsontree, 0, m_distributionData, requiredInstances, listOfFc, cSetList, ccSer, d); } else { System.out.println("No json tree specified, failing to process tree"); } setRequiredInst(requiredInstances); // Backfit if required if (backfit != null) { backfitData(backfit); } }