naivebayes.NBTubesAI.java Source code

Java tutorial

Introduction

Here is the source code for naivebayes.NBTubesAI.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package naivebayes;

import java.util.Enumeration;
import java.util.HashMap;
import java.util.Random;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;

/**
 * A hand-rolled Naive Bayes classifier for Weka.
 *
 * <p>Training ({@link #buildClassifier(Instances)}) counts, for every
 * non-class attribute, how often each attribute value occurs together with
 * each class and stores the result as a conditional relative frequency
 * (occurrences divided by the number of training instances of that class).
 * No smoothing is applied, so a value that was seen in training but never
 * together with a given class contributes a factor of 0 for that class.
 *
 * <p>Prediction multiplies the class prior by the stored conditional
 * frequency of every non-missing attribute value of the instance.
 */
public class NBTubesAI extends AbstractClassifier {

    /**
     * Conditional relative frequencies learned during training.
     * The outer map is keyed by attribute name, the middle map by the
     * attribute value (as a string), and the inner map by the class index
     * (stored as a {@code Double}); the innermost value is the frequency of
     * that attribute value among the training instances of that class.
     */
    protected HashMap<String, HashMap<String, HashMap<Double, Double>>> distribution;

    /** Number of training instances per class index (stored as a {@code Double} key). */
    protected HashMap<Double, Integer> classCount;

    /** Serialization version identifier. */
    public static final long serialVersionUID = -6079756312492915625L;

    /** Number of training instances that had a class value. */
    public int numInstance;

    /** Private copy of the training data, with class-less instances removed. */
    protected Instances m_Instances;

    /**
     * Builds the frequency tables from the given training data.
     *
     * <p>The data is copied first, so the caller's dataset is not modified.
     * Instances without a class value are dropped. Missing attribute values
     * are not recorded in {@link #distribution}.
     *
     * @param data the training data
     * @throws Exception if the classifier cannot be built
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        distribution = new HashMap<>();
        classCount = new HashMap<>();

        // Work on a single private copy and drop instances without a class.
        m_Instances = new Instances(data);
        m_Instances.deleteWithMissingClass();
        numInstance = m_Instances.numInstances();

        // Start every class at zero so classes absent from the data still have an entry.
        for (int c = 0; c < m_Instances.classAttribute().numValues(); c++) {
            classCount.put((double) c, 0);
        }
        Enumeration<Instance> forCount = m_Instances.enumerateInstances();
        while (forCount.hasMoreElements()) {
            double classValue = forCount.nextElement().classValue();
            classCount.put(classValue, classCount.get(classValue) + 1);
        }

        System.out.println("JMLAH KELAS:" + m_Instances.numClasses());
        System.out.println(classCount.toString());

        // enumerateAttributes() is used for the non-class attributes of the dataset.
        Enumeration<Attribute> enumAttr = m_Instances.enumerateAttributes();
        while (enumAttr.hasMoreElements()) {
            Attribute attr = enumAttr.nextElement();

            HashMap<String, HashMap<Double, Double>> perValue = distribution.get(attr.name());
            if (perValue == null) {
                perValue = new HashMap<>();
                distribution.put(attr.name(), perValue);
            }

            Enumeration<Instance> enumInst = m_Instances.enumerateInstances();
            while (enumInst.hasMoreElements()) {
                Instance inst = enumInst.nextElement();
                // Missing values are skipped at prediction time, so recording
                // them here would only add an entry that is never read.
                if (inst.isMissing(attr)) {
                    continue;
                }
                String value = inst.stringValue(attr);
                double classValue = inst.classValue();

                HashMap<Double, Double> perClass = perValue.get(value);
                if (perClass == null) {
                    // First time this value is seen: give every class a zero entry.
                    perClass = new HashMap<>();
                    for (int c = 0; c < m_Instances.numClasses(); c++) {
                        perClass.put((double) c, 0.0);
                    }
                    perValue.put(value, perClass);
                }
                // Adding 1/classCount per occurrence accumulates count/classCount.
                perClass.put(classValue,
                        perClass.get(classValue) + (1.0 / classCount.get(classValue)));
            }
        }
        System.out.println(distribution.toString());
        System.out.println(classCount.toString());
    }

    /**
     * Predicts the class of the given instance.
     *
     * @param instance the instance to classify
     * @return the index of the class with the highest score; index 0 when
     *         no class has a positive score
     * @throws Exception if the prediction cannot be computed
     */
    @Override
    public double classifyInstance(Instance instance) throws Exception {
        // Delegate so that both prediction entry points always agree.
        double[] scores = distributionForInstance(instance);
        int bestIndex = 0;
        double bestScore = 0;
        for (int i = 0; i < scores.length; i++) {
            if (bestScore < scores[i]) {
                bestIndex = i;
                bestScore = scores[i];
            }
        }
        return bestIndex;
    }

    /**
     * Computes the class membership estimates for the given instance.
     *
     * @param instance the instance to score
     * @return one value per class index; the values sum to 1 whenever at
     *         least one class has a positive score, otherwise all are 0
     * @throws Exception if the distribution cannot be computed
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        int jumlahKelas = instance.classAttribute().numValues();
        double[] classifyResult = new double[jumlahKelas];

        double total = 0.0;
        for (int i = 0; i < jumlahKelas; i++) {
            classifyResult[i] = scoreForClass(instance, (double) i);
            total += classifyResult[i];
        }

        // Scale to a proper probability distribution; skip when nothing is
        // positive to avoid dividing by zero.
        if (total > 0.0) {
            for (int i = 0; i < jumlahKelas; i++) {
                classifyResult[i] /= total;
            }
        }
        return classifyResult;
    }

    /**
     * Computes the unnormalized Naive Bayes score (prior times conditional
     * frequencies) of one class for the given instance.
     */
    private double scoreForClass(Instance instance, Double classKey) {
        Integer count = classCount.get(classKey);
        if (count == null || numInstance <= 0) {
            // Unknown class index or empty training set: no evidence for this class.
            return 0.0;
        }
        double score = (double) count / numInstance;

        Enumeration<Attribute> enumAttr = instance.enumerateAttributes();
        while (enumAttr.hasMoreElements()) {
            Attribute attr = enumAttr.nextElement();
            if (instance.isMissing(attr)) {
                continue;
            }
            HashMap<String, HashMap<Double, Double>> perValue = distribution.get(attr.name());
            if (perValue == null) {
                continue; // attribute was not present during training
            }
            HashMap<Double, Double> perClass = perValue.get(instance.stringValue(attr));
            if (perClass == null) {
                // Value never seen in training: it carries no information for
                // any class, so it is ignored rather than zeroing every class.
                continue;
            }
            Double probability = perClass.get(classKey);
            if (probability != null) {
                score *= probability;
            }
        }
        return score;
    }

    /**
     * Returns the capabilities declared by this classifier.
     *
     * @return the capabilities: nominal and numeric attributes, nominal
     *         class, missing class values, and no minimum instance count
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        // attributes
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        // NOTE(review): training reads every attribute through
        // Instance.stringValue(), so numeric attributes presumably have to be
        // discretized by the caller beforehand - confirm before relying on this.
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        // class
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);

        // instances
        result.setMinimumNumberInstances(0);

        return result;
    }

}