Usage examples for weka.core.Instances.randomize
public void randomize(Random random)
From source file: uzholdem.classifier.OnlineMultilayerPerceptron.java
License: Open Source License
public void trainModel(Instances aInstances, int numIterations) throws Exception { // setup m_instances if (this.m_instances == null) { this.m_instances = new Instances(aInstances, 0, aInstances.size()); }/*from ww w . j a v a 2 s .com*/ /////////// if (m_useNomToBin) { if (this.m_nominalToBinaryFilter == null) { m_nominalToBinaryFilter = new NominalToBinary(); try { m_nominalToBinaryFilter.setInputFormat(m_instances); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); return; } } aInstances = Filter.useFilter(aInstances, m_nominalToBinaryFilter); } Instances epochInstances = new Instances(aInstances); epochInstances.randomize(new Random()); Instances valSet = new Instances(aInstances, (int) (aInstances.size() * 0.3)); for (int i = 0; i < valSet.size(); i++) { valSet.add(epochInstances.instance(0)); epochInstances.delete(0); } m_instances = epochInstances; double right = 0; double driftOff = 0; double lastRight = Double.POSITIVE_INFINITY; double bestError = Double.POSITIVE_INFINITY; double tempRate; double totalWeight = 0; double totalValWeight = 0; double origRate = m_learningRate; //only used for when reset int numInVal = valSet.numInstances(); for (int noa = numInVal; noa < m_instances.numInstances(); noa++) { if (!m_instances.instance(noa).classIsMissing()) { totalWeight += m_instances.instance(noa).weight(); } } if (m_valSize != 0) { for (int noa = 0; noa < valSet.numInstances(); noa++) { if (!valSet.instance(noa).classIsMissing()) { totalValWeight += valSet.instance(noa).weight(); } } } m_stopped = false; for (int noa = 1; noa < 50 + 1; noa++) { right = 0; for (int nob = numInVal; nob < m_instances.numInstances(); nob++) { m_currentInstance = m_instances.instance(nob); if (!m_currentInstance.classIsMissing()) { //this is where the network updating (and training occurs, for the //training set resetNetwork(); calculateOutputs(); tempRate = m_learningRate * m_currentInstance.weight(); if (m_decay) { tempRate /= noa; } right += 
(calculateErrors() / m_instances.numClasses()) * m_currentInstance.weight(); updateNetworkWeights(tempRate, m_momentum); } } right /= totalWeight; if (Double.isInfinite(right) || Double.isNaN(right)) { m_instances = null; throw new Exception("Network cannot train. Try restarting with a" + " smaller learning rate."); } ////////////////////////do validation testing if applicable if (m_valSize != 0) { right = 0; for (int nob = 0; nob < valSet.numInstances(); nob++) { m_currentInstance = valSet.instance(nob); if (!m_currentInstance.classIsMissing()) { //this is where the network updating occurs, for the validation set resetNetwork(); calculateOutputs(); right += (calculateErrors() / valSet.numClasses()) * m_currentInstance.weight(); //note 'right' could be calculated here just using //the calculate output values. This would be faster. //be less modular } } if (right < lastRight) { if (right < bestError) { bestError = right; // save the network weights at this point for (int noc = 0; noc < m_numClasses; noc++) { m_outputs[noc].saveWeights(); } driftOff = 0; } } else { driftOff++; } lastRight = right; if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) { for (int noc = 0; noc < m_numClasses; noc++) { m_outputs[noc].restoreWeights(); } m_accepted = true; } right /= totalValWeight; } m_epoch = noa; m_error = right; //shows what the neuralnet is upto if a gui exists. if (m_accepted) { m_instances = new Instances(m_instances, 0); return; } } }