com.mycompany.neuralnetwork.NeuralNetworkClassifier.java Source code

Java tutorial

Introduction

Here is the source code for com.mycompany.neuralnetwork.NeuralNetworkClassifier.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.mycompany.neuralnetwork;

import java.util.ArrayList;
import java.util.List;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Weka classifier that delegates learning and prediction to a {@link Network}.
 *
 * <p>Every non-class attribute of an instance becomes one network input; missing attribute
 * values are fed to the network as {@code 0.0}. The network is built with {@code layers - 1}
 * layers that each have as many nodes as there are inputs, followed by an output layer with one
 * node per class label. A prediction is the index of the output node with the largest value.
 *
 * @author Besseym
 */
public class NeuralNetworkClassifier extends Classifier {
    /** Trained network; {@code null} until {@link #buildClassifier(Instances)} has run. */
    Network network;
    /** Total number of layers, including the output layer. */
    int layers;
    /** Number of full passes over the training data. */
    int iterations;
    /** Learning factor handed to {@code Network.train} for every training example. */
    double learningFactor;

    /**
     * Creates an untrained classifier.
     *
     * @param layers total number of network layers, including the output layer
     * @param iterations number of passes over the training data made by
     *     {@link #buildClassifier(Instances)}
     * @param learningFactor learning factor passed to the network on every training step
     */
    public NeuralNetworkClassifier(int layers, int iterations, double learningFactor) {
        this.layers = layers;
        this.iterations = iterations;
        this.learningFactor = learningFactor;
    }

    /**
     * Builds a fresh network sized for the given data and trains it for {@code iterations}
     * passes over all instances.
     *
     * @param instances training data; its class attribute must be set
     * @throws Exception if the network cannot be built or trained
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        // Every attribute except the class attribute is a network input.
        int inputCount = instances.numAttributes() - 1;

        List<Integer> nodesPerLayer = new ArrayList<>();
        for (int i = 0; i < layers - 1; i++) {
            nodesPerLayer.add(inputCount);
        }
        nodesPerLayer.add(outputNodeCount(instances));

        network = new Network(inputCount, nodesPerLayer);

        for (int epoch = 0; epoch < iterations; epoch++) {
            for (int k = 0; k < instances.numInstances(); k++) {
                Instance instance = instances.instance(k);

                // An unlabeled example would hand a NaN target to the network.
                if (instance.classIsMissing()) {
                    continue;
                }

                // train() returns an error value for the example; it is not needed here.
                network.train(toInputVector(instance),
                        instance.value(instance.classIndex()), learningFactor);
            }
        }
    }

    /**
     * Predicts the class of an instance as the index of the largest network output.
     *
     * @param instance the instance to classify
     * @return index of the output node with the largest value ({@code 0} if the network
     *     produces no outputs)
     * @throws IllegalStateException if {@link #buildClassifier(Instances)} has not been called
     * @throws Exception if the network fails to evaluate the instance
     */
    @Override
    public double classifyInstance(Instance instance) throws Exception {
        if (network == null) {
            throw new IllegalStateException(
                    "buildClassifier must be called before classifyInstance");
        }

        List<Double> outputs = network.getOutputs(toInputVector(instance));

        // Start below every possible value so the true maximum always wins,
        // even if all outputs are negative.
        double largestValue = Double.NEGATIVE_INFINITY;
        int largestIndex = 0;
        for (int i = 0; i < outputs.size(); i++) {
            double value = outputs.get(i);
            if (value > largestValue) {
                largestValue = value;
                largestIndex = i;
            }
        }

        return largestIndex;
    }

    /**
     * Returns the number of output nodes needed for the data's class attribute.
     *
     * <p>For a nominal class this is the number of declared labels, so that every possible
     * class index has an output node even when a label is absent from the training data.
     * For any other class type it is the number of distinct class values present in the data.
     */
    private static int outputNodeCount(Instances instances) {
        if (instances.classAttribute().isNominal()) {
            return instances.numClasses();
        }
        return instances.numDistinctValues(instances.classIndex());
    }

    /**
     * Converts an instance into the network input list: all attribute values in attribute
     * order, with the class attribute left out and missing (NaN) values replaced by 0.0.
     */
    private static List<Double> toInputVector(Instance instance) {
        int classIndex = instance.classIndex();
        List<Double> input = new ArrayList<>();
        for (int i = 0; i < instance.numAttributes(); i++) {
            if (i == classIndex) {
                continue;
            }
            double value = instance.value(i);
            input.add(Double.isNaN(value) ? 0.0 : value);
        }
        return input;
    }

}