Java examples for Machine Learning / AI: Weka
Cluster Bagging with Weka
import java.io.FileWriter; import java.util.ArrayList; import weka.classifiers.Evaluation; import weka.classifiers.evaluation.Prediction; import weka.classifiers.meta.Bagging; import weka.classifiers.trees.RandomForest; import weka.clusterers.SimpleKMeans; import weka.core.Instances; import weka.core.converters.ConverterUtils.DataSource; import weka.filters.Filter; import weka.filters.unsupervised.attribute.Add; import weka.filters.unsupervised.instance.RemoveFrequentValues; import au.com.bytecode.opencsv.CSVWriter; public class Main { public static void main(String[] args) throws Exception { //from w ww . j ava 2 s . c o m Instances train = DataSource read("./train1.arff"); int cid1 = train.numAttributes() - 1; train.setClassIndex(cid1); Instances validation = DataSource read("./validation1.arff"); int cid2 = validation.numAttributes() - 1; validation.setClassIndex(cid2); Instances test = DataSource read("./test1.arff"); int cid3 = test.numAttributes() - 1; test.setClassIndex(cid3); //Remove fraud class instances RemoveFrequentValues remove = new RemoveFrequentValues(); remove.setInputFormat(train); remove.setAttributeIndex("last"); remove.setNumValues(1); Instances train_ok = Filter.useFilter(train, remove); int cid4 = train_ok.numAttributes() - 1; train_ok.setClassIndex(cid4); //Remove ok class instances RemoveFrequentValues remove1 = new RemoveFrequentValues(); remove1.setInputFormat(train); remove1.setAttributeIndex("last"); remove1.setNumValues(1); remove1.setUseLeastValues(true); Instances train_fraud = Filter.useFilter(train, remove1); int cid5 = train_fraud.numAttributes() - 1; train_fraud.setClassIndex(cid5); //remove class attribute for clustering weka.filters.unsupervised.attribute.Remove filter = new weka.filters.unsupervised.attribute.Remove(); filter.setAttributeIndices("" + (train_ok.classIndex() + 1)); filter.setInputFormat(train_ok); Instances dataClusterer = Filter.useFilter(train_ok, filter); //cluster using K-means SimpleKMeans cluster = new 
SimpleKMeans(); cluster.setNumClusters(146); cluster.buildClusterer(dataClusterer); train_ok = cluster.getClusterCentroids(); //Add deleted class attribute Add add_attribute = new Add(); add_attribute.setAttributeName("status"); add_attribute.setAttributeIndex("last"); add_attribute.setNominalLabels("0,1"); //SelectedTag value= //add_attribute.setAttributeType(value); add_attribute.setInputFormat(train_ok); train_ok = Filter.useFilter(train_ok, add_attribute); for (int i = 0; i < train_ok.numInstances(); i++) { train_ok.instance(i) setValue(train_ok.numAttributes() - 1, "0"); } int cid7 = train_ok.numAttributes() - 1; train_ok.setClassIndex(cid7); //combine train_ok and train_fraud for (int i = 0; i < train_fraud.numInstances(); i++) train_ok.add(train_fraud.instance(i)); train = train_ok; int cid6 = train.numAttributes() - 1; train.setClassIndex(cid6); //Bagging RF RandomForest rf = new RandomForest(); Bagging tree = new Bagging(); tree.setClassifier(rf); tree.buildClassifier(train); Evaluation eval = new Evaluation(train); eval.evaluateModel(tree, validation); ArrayList<Prediction> al = eval.predictions(); ArrayList<String[]> as = new ArrayList<String[]>(al.size()); for (int i = 0; i < al.size(); i++) { String[] s = new String[1]; s[0] = al.get(i).toString(); s[0] = s[0].substring(9, 11); as.add(s); } ArrayList<String[]> li = new ArrayList<String[]>(al.size()); li.addAll(as); String csv = "./output.csv"; CSVWriter writer = new CSVWriter(new FileWriter(csv)); writer.writeAll(li); writer.close(); } }