List of usage examples for weka.core Instances numInstances
public int numInstances()
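For orientation, a minimal sketch of the call itself, assuming Weka is on the classpath and an ARFF file exists at the illustrative path "data.arff". numInstances() is the usual loop bound when visiting every row by index, as the examples below show.

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class NumInstancesDemo {
    public static void main(String[] args) throws Exception {
        // Load a dataset; "data.arff" is a placeholder path
        Instances data = DataSource.read("data.arff");
        // numInstances() returns the number of rows (instances) in the dataset
        System.out.println("Number of instances: " + data.numInstances());
        // Common pattern: iterate over all instances by index
        for (int i = 0; i < data.numInstances(); i++) {
            System.out.println(data.instance(i));
        }
    }
}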
From source file:gyc.UnderOverBoostM1.java
License:Open Source License
/**
 * Boosting method. Boosts using resampling.
 *
 * @param data the training data to be used for generating the
 *             boosted classifier.
 * @throws Exception if the classifier could not be built successfully
 */
protected void buildClassifierUsingResampling(Instances data) throws Exception {

    Instances trainData, sample, training;
    double epsilon, reweight, sumProbs;
    Evaluation evaluation;
    int numInstances = data.numInstances();
    Random randomInstance = new Random(m_Seed);
    int resamplingIterations = 0;

    // Initialize data
    m_Betas = new double[m_Classifiers.length];
    m_NumIterationsPerformed = 0;

    // Create a copy of the data so that when the weights are diddled
    // with it doesn't mess up the weights for anyone else
    training = new Instances(data, 0, numInstances);
    sumProbs = training.sumOfWeights();
    for (int i = 0; i < training.numInstances(); i++) {
        training.instance(i).setWeight(training.instance(i).weight() / sumProbs);
    }

    // Do bootstrap iterations
    int b = 10; // only used by the commented-out schedule below
    for (m_NumIterationsPerformed = 0; m_NumIterationsPerformed < m_Classifiers.length;
            m_NumIterationsPerformed++) {
        if (m_Debug) {
            System.err.println("Training classifier " + (m_NumIterationsPerformed + 1));
        }

        // Select instances to train the classifier on
        if (m_WeightThreshold < 100) {
            trainData = selectWeightQuantile(training, (double) m_WeightThreshold / 100);
        } else {
            trainData = new Instances(training);
        }

        // Resample
        resamplingIterations = 0;
        double[] weights = new double[trainData.numInstances()];
        for (int i = 0; i < weights.length; i++) {
            weights[i] = trainData.instance(i).weight();
        }
        do {
            sample = trainData.resampleWithWeights(randomInstance, weights);

            int[] classNum = sample.attributeStats(sample.classIndex()).nominalCounts;
            int minC, nMin = classNum[0];
            int majC, nMaj = classNum[1];
            if (nMin < nMaj) {
                minC = 0;
                majC = 1;
            } else {
                minC = 1;
                majC = 0;
                nMin = classNum[1];
                nMaj = classNum[0];
            }

            // Balance the data that boosting generated for training the base
            // classifier: (b% * nMaj) instances are taken from each class.
            double pb = 100.0 * (nMin + nMaj) / 2 / nMaj;
            // if (m_NumIterationsPerformed + 1 > (m_Classifiers.length / 10)) b += 10;
            Instances sampleData = randomSampling(sample, majC, minC, (int) pb, randomInstance);

            // Build and evaluate classifier
            m_Classifiers[m_NumIterationsPerformed].buildClassifier(sampleData);
            evaluation = new Evaluation(data);
            evaluation.evaluateModel(m_Classifiers[m_NumIterationsPerformed], training);
            epsilon = evaluation.errorRate();
            resamplingIterations++;
        } while (Utils.eq(epsilon, 0) && (resamplingIterations < MAX_NUM_RESAMPLING_ITERATIONS));

        // Stop if error too big or 0
        if (Utils.grOrEq(epsilon, 0.5) || Utils.eq(epsilon, 0)) {
            if (m_NumIterationsPerformed == 0) {
                m_NumIterationsPerformed = 1; // if we're the first we have to use it
            }
            break;
        }

        // Determine the weight to assign to this model
        m_Betas[m_NumIterationsPerformed] = Math.log((1 - epsilon) / epsilon);
        reweight = (1 - epsilon) / epsilon;
        if (m_Debug) {
            System.err.println("\terror rate = " + epsilon + "  beta = "
                    + m_Betas[m_NumIterationsPerformed]);
        }

        // Update instance weights
        setWeights(training, reweight);
    }
}
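The core pattern in this example, normalizing instance weights and then drawing a weighted bootstrap sample, can be exercised on its own. A minimal sketch, assuming an ARFF file at the illustrative path "data.arff":

import java.util.Random;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class WeightedResampleDemo {
    public static void main(String[] args) throws Exception {
        Instances data = DataSource.read("data.arff"); // placeholder path
        // Normalize instance weights so they sum to 1, as the boosting code does
        double sum = data.sumOfWeights();
        double[] weights = new double[data.numInstances()];
        for (int i = 0; i < data.numInstances(); i++) {
            data.instance(i).setWeight(data.instance(i).weight() / sum);
            weights[i] = data.instance(i).weight();
        }
        // Draw a bootstrap sample of the same size, biased by the weights
        Instances sample = data.resampleWithWeights(new Random(1), weights);
        System.out.println("Original: " + data.numInstances()
                + ", sample: " + sample.numInstances());
    }
}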
From source file:Helper.CustomFilter.java
private Instances toRange(Instances structure, int index) throws Exception {
    Attribute attr = structure.attribute(index);
    Attribute classlabel = structure.attribute(structure.numAttributes() - 1);
    String label = structure.instance(0).stringValue(classlabel);
    double threshold = structure.instance(0).value(index);
    for (int i = 0; i < structure.numInstances(); i++) {
        // Each time the class label changes, take the attribute value at the
        // boundary as the new threshold
        if (!structure.instance(i).stringValue(classlabel).equals(label)) {
            label = structure.instance(i).stringValue(classlabel);
            threshold = structure.instance(i).value(index);
        }
        structure.instance(i).setValue(attr, threshold);
    }
    return structure;
}
From source file:hr.irb.fastRandomForest.NakedFastRfBagging.java
License:Open Source License
/**
 * Bagging method. Produces DataCache objects with bootstrap samples of the
 * original data, and feeds them to the base classifier (which can only be a
 * FastRandomTree).
 *
 * @param data         The training set to be used for generating the bagged classifier.
 * @param numThreads   The number of simultaneous threads to use for computation.
 *                     Pass zero (0) for autodetection.
 * @param motherForest A reference to the FastRandomForest object that invoked this.
 * @throws Exception if the classifier could not be built successfully
 */
public void buildClassifier(Instances data, final int numThreads,
        final NakedFastRandomForest motherForest) throws Exception {

    // can the classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    if (!(m_Classifier instanceof NakedFastRandomTree))
        throw new IllegalArgumentException("The NakedFastRfBagging class accepts "
                + "only NakedFastRandomTree as its base classifier.");

    /*
     * We fill the m_Classifiers array by creating lots of trees with new()
     * because this is much faster than using serialization to deep-copy the
     * one tree in m_Classifier - this is what super.buildClassifier(data)
     * normally does.
     */
    m_Classifiers = new Classifier[m_NumIterations];
    for (int i = 0; i < m_Classifiers.length; i++) {
        final NakedFastRandomTree curTree = new NakedFastRandomTree();
        // all parameters for training will be looked up in the motherForest
        // (maxDepth, k_Value)
        curTree.m_MotherForest = motherForest;
        // 0.99: references to these arrays get passed down all nodes so the
        // arrays can be re-used
        // 0.99: these arrays are of size two as all splits are now binary -
        // even categorical ones
        curTree.tempProps = new double[2];
        curTree.tempDists = new double[2][];
        curTree.tempDists[0] = new double[data.numClasses()];
        curTree.tempDists[1] = new double[data.numClasses()];
        curTree.tempDistsOther = new double[2][];
        curTree.tempDistsOther[0] = new double[data.numClasses()];
        curTree.tempDistsOther[1] = new double[data.numClasses()];
        m_Classifiers[i] = curTree;
    }

    // this was SLOW: it takes approx. half as much time as training the
    // forest afterwards (!)
    // super.buildClassifier(data);

    if (m_CalcOutOfBag && (m_BagSizePercent != 100)) {
        throw new IllegalArgumentException("Bag size needs to be 100% if "
                + "out-of-bag error is to be calculated!");
    }

    // sorting is performed inside this constructor
    final DataCache myData = new DataCache(data);

    final int bagSize = data.numInstances() * m_BagSizePercent / 100;
    final Random random = new Random(m_Seed);
    final boolean[][] inBag = new boolean[m_Classifiers.length][];

    // thread management
    final ExecutorService threadPool = Executors.newFixedThreadPool(
            numThreads > 0 ? numThreads : Runtime.getRuntime().availableProcessors());
    final List<Future<?>> futures = new ArrayList<Future<?>>(m_Classifiers.length);

    try {
        for (int treeIdx = 0; treeIdx < m_Classifiers.length; treeIdx++) {
            // create the in-bag dataset (and remember what's in bag)
            // for computing the out-of-bag error later
            final DataCache bagData = myData.resample(bagSize, random);
            bagData.reusableRandomGenerator = bagData.getRandomNumberGenerator(random.nextInt());
            inBag[treeIdx] = bagData.inBag; // store for OOB error calculation

            // build the classifier
            if (m_Classifiers[treeIdx] instanceof NakedFastRandomTree) {
                final FastRandomTree aTree = (FastRandomTree) m_Classifiers[treeIdx];
                aTree.data = bagData;
                final Future<?> future = threadPool.submit(aTree);
                futures.add(future);
            } else {
                throw new IllegalArgumentException("The FastRfBagging class accepts "
                        + "only NakedFastRandomTree as its base classifier.");
            }
        }

        // make sure all trees have been trained before proceeding
        for (int treeIdx = 0; treeIdx < m_Classifiers.length; treeIdx++) {
            futures.get(treeIdx).get();
        }

        // [jhostetler] 'm_FeatureImportances' and 'computeOOBError()' are
        // private, so we just don't compute them. (The original OOB-error and
        // attribute-scrambling feature-importance code is left out here.)

        threadPool.shutdown();
    } finally {
        threadPool.shutdownNow();
    }
}
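The bag-size arithmetic above (data.numInstances() * m_BagSizePercent / 100) is the standard bagging calculation. A minimal sketch of the same idea with stock weka.core.Instances.resample(), so the FastRandomForest classes are not required; the file path is illustrative:

import java.util.Random;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class BagSizeDemo {
    public static void main(String[] args) throws Exception {
        Instances data = DataSource.read("data.arff"); // placeholder path
        int bagSizePercent = 100;
        int bagSize = data.numInstances() * bagSizePercent / 100;
        // resample(Random) draws a bootstrap sample (with replacement) of equal size
        Instances bag = data.resample(new Random(42));
        System.out.println("bagSize = " + bagSize
                + ", bag contains " + bag.numInstances() + " instances");
    }
}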
From source file:hsa_jni.hsa_jni.EvaluatePeriodicHeldOutTestBatch.java
License:Open Source License
@Override
protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
    Classifier learner = (Classifier) getPreparedClassOption(this.learnerOption);
    InstanceStream stream = (InstanceStream) getPreparedClassOption(this.streamOption);
    ClassificationPerformanceEvaluator evaluator =
            (ClassificationPerformanceEvaluator) getPreparedClassOption(this.evaluatorOption);
    learner.setModelContext(stream.getHeader());
    long instancesProcessed = 0;
    LearningCurve learningCurve = new LearningCurve("evaluation instances");
    File dumpFile = this.dumpFileOption.getFile();
    PrintStream immediateResultStream = null;
    if (dumpFile != null) {
        try {
            if (dumpFile.exists()) {
                immediateResultStream = new PrintStream(new FileOutputStream(dumpFile, true), true);
            } else {
                immediateResultStream = new PrintStream(new FileOutputStream(dumpFile), true);
            }
        } catch (Exception ex) {
            throw new RuntimeException("Unable to open immediate result file: " + dumpFile, ex);
        }
    }
    boolean firstDump = true;
    InstanceStream testStream = null;
    int testSize = this.testSizeOption.getValue();
    if (this.cacheTestOption.isSet()) {
        monitor.setCurrentActivity("Caching test examples...", -1.0);
        Instances testInstances = new Instances(stream.getHeader(), this.testSizeOption.getValue());
        while (testInstances.numInstances() < testSize) {
            testInstances.add(stream.nextInstance());
            if (testInstances.numInstances() % INSTANCES_BETWEEN_MONITOR_UPDATES == 0) {
                if (monitor.taskShouldAbort()) {
                    return null;
                }
                monitor.setCurrentActivityFractionComplete(
                        (double) testInstances.numInstances() / (double) (this.testSizeOption.getValue()));
            }
        }
        testStream = new CachedInstancesStream(testInstances);
    } else {
        testStream = stream; // the original's stream-copy/skip logic is commented out
    }
    instancesProcessed = 0;
    TimingUtils.enablePreciseTiming();
    double totalTrainTime = 0.0;
    while ((this.trainSizeOption.getValue() < 1 || instancesProcessed < this.trainSizeOption.getValue())
            && stream.hasMoreInstances() == true) {
        monitor.setCurrentActivityDescription("Training...");
        long instancesTarget = instancesProcessed + this.sampleFrequencyOption.getValue();
        ArrayList<Instance> instanceCache = new ArrayList<Instance>();
        long trainStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
        double lastTrainTime = 0;
        while (instancesProcessed < instancesTarget && stream.hasMoreInstances() == true) {
            instanceCache.add(stream.nextInstance());
            instancesProcessed++;
            if (instancesProcessed % INSTANCES_BETWEEN_MONITOR_UPDATES == 0) {
                if (monitor.taskShouldAbort()) {
                    return null;
                }
                monitor.setCurrentActivityFractionComplete(
                        (double) (instancesProcessed) / (double) (this.trainSizeOption.getValue()));
            }
            if (instanceCache.size() % 1000 == 0) {
                trainStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
                for (Instance inst : instanceCache) {
                    learner.trainOnInstance(inst);
                }
                lastTrainTime += TimingUtils
                        .nanoTimeToSeconds(TimingUtils.getNanoCPUTimeOfCurrentThread() - trainStartTime);
                instanceCache.clear();
            }
        }
        trainStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
        for (Instance inst : instanceCache) {
            learner.trainOnInstance(inst);
        }
        if (learner instanceof BatchClassifier)
            ((BatchClassifier) learner).commit();
        lastTrainTime += TimingUtils
                .nanoTimeToSeconds(TimingUtils.getNanoCPUTimeOfCurrentThread() - trainStartTime);
        totalTrainTime += lastTrainTime;
        if (totalTrainTime > this.trainTimeOption.getValue()) {
            break;
        }
        if (this.cacheTestOption.isSet()) {
            testStream.restart();
        }
        evaluator.reset();
        long testInstancesProcessed = 0;
        monitor.setCurrentActivityDescription("Testing (after "
                + StringUtils.doubleToString(
                        ((double) (instancesProcessed) / (double) (this.trainSizeOption.getValue()) * 100.0), 2)
                + "% training)...");
        long testStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
        int instCount = 0;
        for (instCount = 0; instCount < testSize; instCount++) {
            if (stream.hasMoreInstances() == false) {
                break;
            }
            // Hide the true class, predict, then restore it before scoring
            Instance testInst = (Instance) testStream.nextInstance().copy();
            double trueClass = testInst.classValue();
            testInst.setClassMissing();
            double[] prediction = learner.getVotesForInstance(testInst);
            testInst.setClassValue(trueClass);
            evaluator.addResult(testInst, prediction);
            testInstancesProcessed++;
            if (testInstancesProcessed % INSTANCES_BETWEEN_MONITOR_UPDATES == 0) {
                if (monitor.taskShouldAbort()) {
                    return null;
                }
                monitor.setCurrentActivityFractionComplete(
                        (double) testInstancesProcessed / (double) (testSize));
            }
        }
        if (instCount != testSize) {
            break;
        }
        double testTime = TimingUtils
                .nanoTimeToSeconds(TimingUtils.getNanoCPUTimeOfCurrentThread() - testStartTime);
        List<Measurement> measurements = new ArrayList<Measurement>();
        measurements.add(new Measurement("evaluation instances", instancesProcessed));
        measurements.add(new Measurement("total train time", totalTrainTime));
        measurements.add(new Measurement("total train speed", instancesProcessed / totalTrainTime));
        measurements.add(new Measurement("last train time", lastTrainTime));
        measurements.add(new Measurement("last train speed",
                this.sampleFrequencyOption.getValue() / lastTrainTime));
        measurements.add(new Measurement("test time", testTime));
        measurements.add(new Measurement("test speed", this.testSizeOption.getValue() / testTime));
        Measurement[] performanceMeasurements = evaluator.getPerformanceMeasurements();
        for (Measurement measurement : performanceMeasurements) {
            measurements.add(measurement);
        }
        Measurement[] modelMeasurements = learner.getModelMeasurements();
        for (Measurement measurement : modelMeasurements) {
            measurements.add(measurement);
        }
        learningCurve.insertEntry(
                new LearningEvaluation(measurements.toArray(new Measurement[measurements.size()])));
        if (immediateResultStream != null) {
            if (firstDump) {
                immediateResultStream.println(learningCurve.headerToString());
                firstDump = false;
            }
            immediateResultStream.println(learningCurve.entryToString(learningCurve.numEntries() - 1));
            immediateResultStream.flush();
        }
        if (monitor.resultPreviewRequested()) {
            monitor.setLatestResultPreview(learningCurve.copy());
        }
        // (Commented-out early-exit logic based on "active learning leaves"
        // and tree-size measurements, for HoeffdingTree/HoeffdingOptionTree
        // and OzaBoost/OzaBag learners, omitted.)
    }
    if (immediateResultStream != null) {
        immediateResultStream.close();
    }
    return learningCurve;
}
From source file:hurtowniedanych.FXMLController.java
public void trainAndTestKNN() throws FileNotFoundException, IOException, Exception {

    InstanceQuery instanceQuery = new InstanceQuery();
    instanceQuery.setUsername("postgres");
    instanceQuery.setPassword("szupek");
    // Point to the properties file with the PostgreSQL settings
    instanceQuery.setCustomPropsFile(new File("./src/data/DatabaseUtils.props"));

    String query = "select ks.wydawnictwo, ks.gatunek, kl.miasto\n"
            + "from zakupy z, ksiazki ks, klienci kl\n"
            + "where ks.id_ksiazka = z.id_ksiazka and kl.id_klient = z.id_klient";
    instanceQuery.setQuery(query);
    Instances data = instanceQuery.retrieveInstances();
    data.setClassIndex(data.numAttributes() - 1);
    data.randomize(new Random());

    // 70/30 train/test split
    double percent = 70.0;
    int trainSize = (int) Math.round(data.numInstances() * percent / 100);
    int testSize = data.numInstances() - trainSize;
    Instances trainData = new Instances(data, 0, trainSize);
    Instances testData = new Instances(data, trainSize, testSize);

    int lSasiadow = Integer.parseInt(textFieldKnn.getText());
    System.out.println(lSasiadow);

    IBk ibk = new IBk(lSasiadow);
    // Set up the distance functions
    EuclideanDistance euclidean = new EuclideanDistance(); // Euclidean
    ManhattanDistance manhatan = new ManhattanDistance();  // Manhattan (city block)
    LinearNNSearch linearNN = new LinearNNSearch();
    if (comboboxOdleglosc.getSelectionModel().getSelectedItem().equals("Manhatan")) {
        linearNN.setDistanceFunction(manhatan);
    } else {
        linearNN.setDistanceFunction(euclidean);
    }
    ibk.setNearestNeighbourSearchAlgorithm(linearNN); // set the neighbour search strategy

    // Build the classifier
    ibk.buildClassifier(trainData);

    Evaluation eval = new Evaluation(trainData);
    eval.evaluateModel(ibk, testData);

    spr.setVisible(true);
    labelKnn.setVisible(true);
    labelOdleglosc.setVisible(true);
    labelKnn.setText(textFieldKnn.getText());
    labelOdleglosc.setText(comboboxOdleglosc.getSelectionModel().getSelectedItem().toString());
    spr.setText(eval.toSummaryString("Result:", true));
}
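The 70/30 split here is a standard use of numInstances() together with the Instances(source, first, count) copy constructor. The same arithmetic in isolation, assuming an ARFF file at the illustrative path "data.arff":

import java.util.Random;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class PercentageSplitDemo {
    public static void main(String[] args) throws Exception {
        Instances data = DataSource.read("data.arff"); // placeholder path
        data.randomize(new Random(1)); // shuffle before splitting
        double percent = 70.0;
        int trainSize = (int) Math.round(data.numInstances() * percent / 100);
        int testSize = data.numInstances() - trainSize;
        // Instances(source, first, count) copies a contiguous block of rows
        Instances train = new Instances(data, 0, trainSize);
        Instances test = new Instances(data, trainSize, testSize);
        System.out.println("train=" + train.numInstances() + " test=" + test.numInstances());
    }
}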
From source file:ia02classificacao.IA02Classificacao.java
/**
 * @param args the command line arguments
 */
public static void main(String[] args) throws Exception {

    // Open the ARFF data set and show the number of instances (rows)
    DataSource arquivo = new DataSource("data/zoo.arff");
    Instances dados = arquivo.getDataSet();
    System.out.println("Instances read: " + dados.numInstances());

    // FILTER: remove the animal-name attribute from the classification
    String[] parametros = new String[] { "-R", "1" };
    Remove filtro = new Remove();
    filtro.setOptions(parametros);
    filtro.setInputFormat(dados);
    dados = Filter.useFilter(dados, filtro);

    // Rank the attributes by information gain
    AttributeSelection selAtributo = new AttributeSelection();
    InfoGainAttributeEval avaliador = new InfoGainAttributeEval();
    Ranker busca = new Ranker();
    selAtributo.setEvaluator(avaliador);
    selAtributo.setSearch(busca);
    selAtributo.SelectAttributes(dados);
    int[] indices = selAtributo.selectedAttributes();
    System.out.println("Selected attributes: " + Utils.arrayToString(indices));

    // Use the J48 algorithm and show the classification of the data as text
    String[] opcoes = new String[1];
    opcoes[0] = "-U";
    J48 arvore = new J48();
    arvore.setOptions(opcoes);
    arvore.buildClassifier(dados);
    System.out.println(arvore);

    // Use the J48 algorithm and show the classification graphically
    /*
    TreeVisualizer tv = new TreeVisualizer(null, arvore.graph(), new PlaceNode2());
    JFrame frame = new javax.swing.JFrame("Decision tree");
    frame.setSize(800, 500);
    frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
    frame.getContentPane().add(tv);
    frame.setVisible(true);
    tv.fitToScreen();
    */

    /*
     * Classification of new data
     */
    System.out.println("\n\nCLASSIFICATION OF NEW DATA");

    // create the attribute values
    double[] vals = new double[dados.numAttributes()];
    vals[0] = 1.0;  // hair
    vals[1] = 0.0;  // feathers
    vals[2] = 0.0;  // eggs
    vals[3] = 1.0;  // milk
    vals[4] = 1.0;  // airborne
    vals[5] = 0.0;  // aquatic
    vals[6] = 0.0;  // predator
    vals[7] = 1.0;  // toothed
    vals[8] = 1.0;  // backbone
    vals[9] = 1.0;  // breathes
    vals[10] = 0.0; // venomous
    vals[11] = 0.0; // fins
    vals[12] = 4.0; // legs
    vals[13] = 1.0; // tail
    vals[14] = 1.0; // domestic
    vals[15] = 1.0; // catsize

    // Create an instance from these attribute values
    Instance meuUnicornio = new DenseInstance(1.0, vals);

    // Attach the instance to the data set
    meuUnicornio.setDataset(dados);

    // Classify the new instance
    double label = arvore.classifyInstance(meuUnicornio);

    // Print the classification result
    System.out.println("New animal: unicorn");
    System.out.println("Classification: " + dados.classAttribute().value((int) label));

    /*
     * Evaluation and error metrics
     */
    System.out.println("\n\nEVALUATION AND ERROR METRICS");
    Classifier cl = new J48();
    Evaluation eval_roc = new Evaluation(dados);
    eval_roc.crossValidateModel(cl, dados, 10, new Random(1));
    System.out.println(eval_roc.toSummaryString());

    /*
     * Confusion matrix
     */
    System.out.println("\n\nCONFUSION MATRIX");
    double[][] confusionMatrix = eval_roc.confusionMatrix();
    System.out.println(eval_roc.toMatrixString());
}
From source file:id3.MyID3.java
/**
 * Decision-tree induction algorithm
 * @param instances  the training data
 * @param attributes the remaining attributes
 * @throws Exception
 */
public void buildMyID3(Instances instances, ArrayList<Attribute> attributes) throws Exception {
    // Check if no instances have reached this node.
    if (instances.numInstances() == 0) {
        classAttribute = null;
        classLabel = Instance.missingValue();
        classDistributionAmongInstances = new double[instances.numClasses()];
        return;
    }
    // Check if all instances contain only one class label
    if (computeEntropy(instances) == 0) {
        currentAttribute = null;
        classDistributionAmongInstances = classDistribution(instances);
        // Label this node with the single class that is present
        for (int i = 0; i < classDistributionAmongInstances.length; i++) {
            if (classDistributionAmongInstances[i] > 0) {
                classLabel = i;
                break;
            }
        }
        classAttribute = instances.classAttribute();
        Utils.normalize(classDistributionAmongInstances);
    } else {
        // Compute the information gain of each attribute
        double[] infoGainAttribute = new double[instances.numAttributes()];
        for (int i = 0; i < instances.numAttributes(); i++) {
            infoGainAttribute[i] = computeIG(instances, instances.attribute(i));
        }
        // Choose the attribute with maximum information gain
        int indexMaxInfoGain = 0;
        double maximumInfoGain = 0.0;
        for (int i = 0; i < (infoGainAttribute.length - 1); i++) {
            if (infoGainAttribute[i] > maximumInfoGain) {
                maximumInfoGain = infoGainAttribute[i];
                indexMaxInfoGain = i;
            }
        }
        currentAttribute = instances.attribute(indexMaxInfoGain);
        // Remove the chosen attribute from the remaining attributes
        ArrayList<Attribute> remainingAttributes = attributes;
        if (!remainingAttributes.isEmpty()) {
            int indexAttributeDeleted = 0;
            for (int i = 0; i < remainingAttributes.size(); i++) {
                if (remainingAttributes.get(i).index() == currentAttribute.index()) {
                    indexAttributeDeleted = i;
                }
            }
            remainingAttributes.remove(indexAttributeDeleted);
        }
        // Split instances on currentAttribute (one branch per value)
        Instances[] instancesSplitBasedAttribute = splitData(instances, currentAttribute);
        subTree = new MyID3[currentAttribute.numValues()];
        for (int i = 0; i < currentAttribute.numValues(); i++) {
            if (instancesSplitBasedAttribute[i].numInstances() == 0) {
                // Empty branch: label it with the parent's majority class
                double[] currentClassDistribution = classDistribution(instances);
                classLabel = 0.0;
                double counterDistribution = 0.0;
                for (int j = 0; j < currentClassDistribution.length; j++) {
                    if (currentClassDistribution[j] > counterDistribution) {
                        counterDistribution = currentClassDistribution[j]; // track the current maximum count
                        classLabel = j;
                    }
                }
                classAttribute = instances.classAttribute();
            } else {
                subTree[i] = new MyID3();
                subTree[i].buildMyID3(instancesSplitBasedAttribute[i], remainingAttributes);
            }
        }
    }
}
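splitData() is called here but not shown in this listing. A plausible reconstruction, an assumption rather than the author's original, splits the data by nominal attribute value inside the same MyID3 class, again using numInstances() to size and walk the data:

// Hypothetical reconstruction of splitData(); not the author's original code
public Instances[] splitData(Instances data, Attribute att) {
    Instances[] splits = new Instances[att.numValues()];
    for (int v = 0; v < att.numValues(); v++) {
        // Empty dataset sharing data's header, with capacity numInstances()
        splits[v] = new Instances(data, data.numInstances());
    }
    for (int i = 0; i < data.numInstances(); i++) {
        // For a nominal attribute, value() returns the index of the value
        splits[(int) data.instance(i).value(att)].add(data.instance(i));
    }
    return splits;
}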
From source file:id3.MyID3.java
/**
 * Computes the class distribution counts
 * @param instances the data to count over
 * @return the class distribution counter
 */
public double[] classDistribution(Instances instances) {
    // Count how many instances fall into each class
    double[] distributionClass = new double[instances.numClasses()];
    for (int i = 0; i < instances.numInstances(); i++) {
        distributionClass[(int) instances.instance(i).classValue()]++;
    }
    return distributionClass;
}
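The same counting loop works on any dataset whose class index is set. A stand-alone sketch, assuming an ARFF file at the illustrative path "data.arff" with the class as the last attribute:

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class ClassDistributionDemo {
    public static void main(String[] args) throws Exception {
        Instances data = DataSource.read("data.arff"); // placeholder path
        data.setClassIndex(data.numAttributes() - 1);  // assume class is last
        double[] dist = new double[data.numClasses()];
        for (int i = 0; i < data.numInstances(); i++) {
            dist[(int) data.instance(i).classValue()]++;
        }
        for (int c = 0; c < dist.length; c++) {
            System.out.println(data.classAttribute().value(c) + ": " + dist[c]);
        }
    }
}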
From source file:id3.MyID3.java
/**
 * Computes the information gain of an attribute
 * @param data the instances
 * @param att  the attribute
 * @return the information gain
 */
public double computeIG(Instances data, Attribute att) {
    // Split instances based on the attribute's values
    Instances[] instancesSplitBasedAttribute = splitData(data, att);
    // Information gain = parent entropy minus the weighted entropy of the splits
    double entrophyOverall = computeEntropy(data);
    for (int i = 0; i < instancesSplitBasedAttribute.length; i++) {
        entrophyOverall -= ((double) instancesSplitBasedAttribute[i].numInstances()
                / (double) data.numInstances()) * computeEntropy(instancesSplitBasedAttribute[i]);
    }
    return entrophyOverall;
}
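In symbols, this computes IG(D, A) = H(D) - sum over values v of (|D_v| / |D|) * H(D_v), where D_v is the subset of D with A = v and both |D_v| and |D| come from numInstances().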
From source file:id3.MyID3.java
/**
 * Computes the entropy of the class distribution
 * @param data the instances
 * @return the computed entropy
 */
public double computeEntropy(Instances data) {
    // Compute class distribution counts from the instances
    double[] distributionClass = classDistribution(data);
    // Entropy = -sum over classes of p * log2(p), skipping empty classes
    double entrophy = 0.0;
    for (int i = 0; i < distributionClass.length; i++) {
        double p = distributionClass[i] / (double) data.numInstances();
        if (p != 0) {
            entrophy -= p * (Math.log(p) / Math.log(2));
        }
    }
    return entrophy;
}
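In formula form this is H = -sum over classes c of p_c * log2(p_c), with p_c = (count of class c) / numInstances(). A stand-alone sketch under the same assumptions as the previous demo (illustrative "data.arff", class attribute last):

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class EntropyDemo {
    public static void main(String[] args) throws Exception {
        Instances data = DataSource.read("data.arff"); // placeholder path
        data.setClassIndex(data.numAttributes() - 1);  // assume class is last
        // Count class frequencies
        double[] dist = new double[data.numClasses()];
        for (int i = 0; i < data.numInstances(); i++) {
            dist[(int) data.instance(i).classValue()]++;
        }
        // Accumulate -p * log2(p), skipping empty classes
        double entropy = 0.0;
        for (double count : dist) {
            if (count > 0) {
                double p = count / data.numInstances();
                entropy -= p * (Math.log(p) / Math.log(2));
            }
        }
        System.out.println("Class entropy: " + entropy);
    }
}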