Java example for machine learning with Weka
Performs k-nearest-neighbor classification/regression with leave-one-out cross-validation, using Weka for data loading
import java.util.Arrays; import weka.core.Instances; import weka.core.converters.ConverterUtils.DataSource; public class KNN { static Instances data; static int k; static int classIndex; static boolean printOn; /**/*from w w w. j a v a 2s . c o m*/ * @param args * @throws Exception */ public static void main(String[] args) { String file = null; DataSource source; try { file = args[0]; source = new DataSource(file); data = source.getDataSet(); } catch (Exception e) { System.err.println("Cannot read from file " + file); return; } try { k = Integer.parseInt(args[1]); } catch (Exception e) { // Default 3, if k is unset or set to invalid value. k = 3; } try { // args[2] is the class name. classIndex = data.attribute(args[2]).index(); } catch (Exception e) { // Default class index. classIndex = data.numAttributes() - 1; if (file.equals("autos.arff")) classIndex = data.numAttributes() - 2; } try { String p = args[3]; printOn = false; } catch (Exception e) { printOn = true; } if (data.classIndex() < 0) data.setClassIndex(classIndex); normalize(); doKNN(); } /** * Normalize all numeric attributes to [0, 1]. * Also delete the instances with missing attributes. */ private static void normalize() { // Do normalization to each attribute. for (int attIndex = 0; attIndex < data.numAttributes(); attIndex++) { // Delete the instances with missing value of this attribute. data.deleteWithMissing(attIndex); if (data.attribute(attIndex).isNominal()) continue; if (attIndex == classIndex) continue; // Normalize non-class and non-nominal attributes. double max = data.instance(0).value(attIndex); double min = max; // Find the max and min value of this attribute in the data set. 
for (int insIndex = 1; insIndex < data.numInstances(); insIndex++) { double value = data.instance(insIndex).value(attIndex); if (max < value) max = value; if (min > value) min = value; } //System.out.println("max="+max+",\tmin="+min+",\t"+data.attribute(attIndex).name()); if (max == min) // No need to normalize if the value of this attribute is a constant. continue; // Normalize the value of this attribute in each instance to [0, 1]. for (int insIndex = 0; insIndex < data.numInstances(); insIndex++) { double value = data.instance(insIndex).value(attIndex); double value_nm = (value - min) / (max - min); data.instance(insIndex).setValue(attIndex, value_nm); } } //System.out.println(data); } /** * */ private static void doKNN() { int testIndex; int numInstances = data.numInstances(); int numClasses = data.numClasses(); int numErrors = 0; // for nonimal prediction(classification) double[] errRate = new double[numInstances]; // for numeric prediction boolean isNominal = data.classAttribute().isNominal(); boolean isNumeric = data.classAttribute().isNumeric(); // Leave One Out Cross Validation. for (testIndex = 0; testIndex < numInstances; testIndex++) { if (printOn) System.out.printf("Instance %4d for testing.\t", testIndex); // Compute the distance to every instance in the data set // except the test instance itself. int index = 0; double[] distanceTo = new double[numInstances]; for (index = 0; index < numInstances; index++) { if (index == testIndex) continue; distanceTo[index] = computeDistance(index, testIndex); } // Distance to myself is the largest. distanceTo[testIndex] = Double.MAX_VALUE; // Find the indexes of the k nearest neighbours. 
int[] nearestNbour = new int[k]; double[] sortedDist = new double[numInstances]; System.arraycopy(distanceTo, 0, sortedDist, 0, numInstances); Arrays.sort(sortedDist); for (int i = 0; i < k; i++) { if (i < k - 1 && sortedDist[i] == sortedDist[i + 1]) continue; for (index = 0; index < numInstances; index++) { if (distanceTo[index] == sortedDist[i]) { nearestNbour[i] = index; if ((++i) == k) break; } } } if (isNominal) { // Each nearest neighbour gives a vote to its class value. String[] classvalue = new String[numClasses]; int[] vote = new int[numClasses]; for (int i = 0; i < numClasses; i++) { classvalue[i] = data.classAttribute().value(i); vote[i] = 0; } for (int j = 0; j < k; j++) { String thisclass = data.instance(nearestNbour[j]) .stringValue(classIndex); int i; for (i = 0; i < numClasses; i++) if (classvalue[i].equals(thisclass)) break; vote[i]++; } // Find the most-voted class value as the prediction. int maxVote = 0; for (int i = 0; i < numClasses; i++) { if (maxVote < vote[i]) maxVote = vote[i]; } String prediction = "neverseethis"; for (int i = 0; i < numClasses; i++) { if (vote[i] == maxVote) { prediction = classvalue[i]; break; } } String target = data.instance(testIndex).stringValue( classIndex); boolean correct = false; if (prediction.equals(target)) correct = true; else numErrors++; if (printOn) System.out.println("prediction=" + prediction + ",\ttarget=" + target + ",\t" + correct); } if (isNumeric) { double prediction, target; double[] nbClass = new double[k]; // class values of the nearest neighbours. double nbClassSum = 0; for (int i = 0; i < k; i++) { nbClass[i] = data.instance(nearestNbour[i]).value( classIndex); nbClassSum += nbClass[i]; } prediction = nbClassSum / k; target = data.instance(testIndex).value(classIndex); //TODO check error rate measure. 
errRate[testIndex] = Math.abs((prediction - target) / target); if (printOn) System.out .printf("prediction=%.3f,\ttarget=%.3f,\terrRate=%.4f\n", prediction, target, errRate[testIndex]); } // end of if(isNumeric) } // end of for(testIndex) // Print LOOCV evaluation results. System.out.print(" LOOCV evaluation result:"); System.out.println("algorithm:\t\t" + k + " Nearest Neighbour"); System.out.println("relation:\t\t" + data.relationName()); System.out.println("class attribute:\t" + data.classAttribute().name()); System.out.print("class type:\t\t"); double errorRate; int numTests = numInstances; // since it is LOOCV if (isNominal) { errorRate = (double) numErrors / (double) numTests; System.out.println("Nominal"); System.out.println("Number of errors:\t" + numErrors + "\nNumber of tests:\t" + numTests); System.out.println("Error Rate:\t\t" + errorRate); } if (isNumeric) { double errorRateSum = 0; for (int i = 0; i < numTests; i++) errorRateSum += errRate[i]; errorRate = errorRateSum / (double) numTests; System.out.println("Numeric"); System.out.println("Number of tests:\t" + numTests); System.out.println("Average Error Rate:\t" + errorRate); } } // end of doKNN() private static double computeDistance(int ins1, int ins2) { // Manhattan distance //TODO other distance? int numAtts = data.numAttributes(); double distance = 0; for (int attIndex = 0; attIndex < numAtts; attIndex++) { if (attIndex == classIndex) continue; if (data.attribute(attIndex).isNominal()) { if (!data.instance(ins1).stringValue(attIndex) .equals(data.instance(ins2).stringValue(attIndex))) { // Distance between two different nominal value is 1. distance += 1; continue; } } // Else, the attributes is Numeric. distance += Math.abs(data.instance(ins1).value(attIndex) - data.instance(ins2).value(attIndex)); } return distance; } }