binarizer.LayoutAnalysis.java Source code

Java tutorial

Introduction

Here is the source code for binarizer.LayoutAnalysis.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package binarizer;

import diuf.diva.dia.ms.ml.ae.SCAE;
//import diuf.diva.dia.ms.ml.ae.classifier.FFCNN;
import diuf.diva.dia.ms.util.DataBlock;
import diuf.diva.dia.ms.util.DataBlockDisplay;
import diuf.diva.dia.ms.util.Image;
import diuf.diva.dia.ms.util.ImageDataBlock;
import diuf.diva.dia.ms.util.Tracer;

import java.awt.Color;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import Hao.wp_NB_Disc_SFS;

/**
 * Pipeline that samples pixels from the training images, exports the SCAE feature vector and
 * ground-truth label of each sampled pixel to a Weka ARFF file, runs feature selection, removes
 * the unselected features from the SCAE, and cross-validates a Naive Bayes classifier on the
 * reduced feature set.
 *
 * @author ms
 */
public class LayoutAnalysis {

    /** Default sampling period: one pixel out of this many is exported per image. */
    private static final int DEFAULT_FRACTION_KEPT = 1234;

    String trainingImPath = null;
    String trainingGtPath = null;
    SCAE scae = null;
    Dataset dataset = null;
    int[] selectedFeatures = null;

    public static int[] colorList = { Color.RED.getRGB(), Color.GREEN.getRGB(), Color.BLUE.getRGB(),
            Color.BLACK.getRGB(), Color.WHITE.getRGB(), Color.GRAY.getRGB() };

    /**
     * Loads the SCAE stored in {@code scaedat} and (re)creates the training dataset from the
     * fixed training image / ground-truth directories.
     *
     * @param scaedat path of the serialized SCAE to load
     * @throws ClassNotFoundException propagated from {@code SCAE.load}
     * @throws IOException propagated from {@code SCAE.load} or the dataset construction
     */
    public void setup(String scaedat) throws ClassNotFoundException, IOException {
        trainingImPath = "data/parzival/training";
        trainingGtPath = "data/parzival-gt/training";
        scae = SCAE.load(scaedat);
        System.out.println("Number of features: " + scae.getFeatureLength());
        dataset = new Dataset(trainingImPath, trainingGtPath, 2); // max 2 images
    }

    /**
     * Writes one CSV line (feature values followed by the class label) per sampled pixel,
     * keeping one pixel out of every 1234.
     *
     * @param txtFile path of the text file to create or overwrite
     * @throws IOException if the file cannot be opened or written
     */
    public void getFeaturesAndLables(String txtFile) throws IOException {
        getFeaturesAndLables(txtFile, DEFAULT_FRACTION_KEPT);
    }

    /**
     * Writes one CSV line (feature values followed by the class label) per sampled pixel.
     * Pixels are visited column by column inside a margin that keeps the SCAE input window
     * within the image; the pixel counter restarts for every image.
     *
     * @param txtFile path of the text file to create or overwrite
     * @param fractionKept sampling period: every {@code fractionKept}-th visited pixel is
     *        exported and the others are skipped; must be positive
     * @throws IllegalArgumentException if {@code fractionKept} is not positive
     * @throws IOException if the file cannot be opened or written
     */
    public void getFeaturesAndLables(String txtFile, int fractionKept) throws IOException {
        if (fractionKept <= 0) {
            throw new IllegalArgumentException("fractionKept must be positive, got " + fractionKept);
        }

        final int inputWidth = scae.getInputWidth();
        final int inputHeight = scae.getInputHeight();
        final int hSize = inputWidth / 2 + 1;
        final int vSize = inputHeight / 2 + 1;

        // try-with-resources: the writer is flushed and closed even if feature extraction throws.
        try (BufferedWriter bw = new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream(new File(txtFile))))) {
            for (GTImage gti : dataset) {
                int n = 0;
                for (int px = hSize; px < gti.pixels.getWidth() - hSize; px++) {
                    for (int py = vSize; py < gti.pixels.getHeight() - vSize; py++) {
                        if ((++n) % fractionKept != 0) {
                            continue;
                        }

                        // Place the input window so that (px, py) is at its centre.
                        scae.setInput(gti.pixels, px - inputWidth / 2, py - inputHeight / 2);
                        scae.forward();

                        float[] features = scae.getFeatureVector();
                        StringBuilder line = new StringBuilder();
                        for (float feature : features) {
                            line.append(Float.toString(feature)).append(',');
                        }
                        line.append(Integer.toString(gti.classNum[px][py]));

                        bw.write(line.toString());
                        bw.newLine();
                    }
                }
            }
        }
    }

    /**
     * Wraps the CSV data of {@code txtFile} into an ARFF file: a relation header, one numeric
     * attribute per SCAE feature, a nominal class attribute with values 0 to 4, and the data
     * lines copied verbatim.
     *
     * @param txtFile path of the CSV file written by {@link #getFeaturesAndLables(String)}
     * @param arffFile path of the ARFF file to create or overwrite
     * @throws IOException if either file cannot be opened, read or written
     */
    public void generateArff(String txtFile, String arffFile) throws IOException {
        final int featureCount = scae.getFeatureLength();

        // The input is opened first so that a missing txtFile does not leave a truncated ARFF
        // file behind; both streams are closed on every exit path.
        try (BufferedReader br = new BufferedReader(new FileReader(txtFile));
                BufferedWriter bw = new BufferedWriter(
                        new OutputStreamWriter(new FileOutputStream(new File(arffFile))))) {
            bw.write("@RELATION " + "pz_train");
            bw.newLine();
            bw.newLine();

            for (int i = 1; i <= featureCount; i++) {
                bw.write("@ATTRIBUTE feature" + i + " NUMERIC");
                bw.newLine();
            }
            bw.write("@ATTRIBUTE class {0,1,2,3,4}");
            bw.newLine();
            bw.newLine();
            bw.write("@DATA");
            bw.newLine();

            for (String line = br.readLine(); line != null; line = br.readLine()) {
                bw.write(line);
                bw.newLine();
            }
        }
    }

    /**
     * Runs {@code wp_NB_Disc_SFS} on the given ARFF file and stores the returned indices in
     * {@link #selectedFeatures}.
     *
     * @param arffFile path of the ARFF file to analyse
     * @throws Exception propagated from the feature selection run
     */
    public void featureSelection(String arffFile) throws Exception {
        wp_NB_Disc_SFS sfs = new wp_NB_Disc_SFS();
        // NOTE(review): the meaning of the three "null" arguments is defined by wp_NB_Disc_SFS.
        selectedFeatures = sfs.run(arffFile, "null", "null", "null");
    }

    /**
     * Deletes from the SCAE every feature that is not listed in {@link #selectedFeatures} and
     * saves the reduced SCAE.
     *
     * @param saveScae path the reduced SCAE is saved to
     * @throws IllegalStateException if {@link #featureSelection(String)} has not been run
     * @throws IOException propagated from {@code SCAE.save}
     */
    public void deleteFeatures(String saveScae) throws IOException {
        if (selectedFeatures == null) {
            throw new IllegalStateException("featureSelection() must be run before deleteFeatures()");
        }

        final int featureCount = scae.getFeatureLength();
        // The last entry of selectedFeatures is deliberately not treated as a feature index.
        // NOTE(review): presumably it is the class attribute index, as Weka appends it to the
        // selected attributes - confirm against wp_NB_Disc_SFS.
        final int nbSelectedFeatures = selectedFeatures.length - 1;

        // Mark the features to keep; duplicates and out-of-range indices are ignored so that
        // the array of deleted features is always sized exactly.
        boolean[] keep = new boolean[featureCount];
        int keptCount = 0;
        for (int j = 0; j < nbSelectedFeatures; j++) {
            int index = selectedFeatures[j];
            if (index >= 0 && index < featureCount && !keep[index]) {
                keep[index] = true;
                keptCount++;
            }
        }

        int[] deletedFeatures = new int[featureCount - keptCount];
        int indexDeletedFeature = 0;
        for (int i = 0; i < featureCount; i++) {
            if (!keep[i]) {
                deletedFeatures[indexDeletedFeature++] = i;
            }
        }

        scae.deleteFeatures(deletedFeatures);
        scae.save(saveScae);
    }

    /**
     * Runs a 10-fold cross-validation (random seed 1) of a Naive Bayes classifier with
     * supervised discretization on the given ARFF file and prints the evaluation summary.
     * The last attribute is used as the class when the data set does not define one.
     *
     * @param arffFile path of the ARFF file to evaluate
     * @return the error rate reported by the evaluation
     * @throws Exception propagated from Weka
     */
    public double crossValidation(String arffFile) throws Exception {
        DataSource source = new DataSource(arffFile);
        Instances trainingData = source.getDataSet();
        if (trainingData.classIndex() == -1) {
            trainingData.setClassIndex(trainingData.numAttributes() - 1);
        }

        NaiveBayes nb = new NaiveBayes();
        nb.setUseSupervisedDiscretization(true);

        Evaluation evaluation = new Evaluation(trainingData);
        evaluation.crossValidateModel(nb, trainingData, 10, new Random(1));
        System.out.println(evaluation.toSummaryString());
        return evaluation.errorRate();
    }

    /**
     * Runs the whole pipeline: export features with the full SCAE, select features, save the
     * reduced SCAE, then export again with the reduced SCAE and cross-validate.
     *
     * @param args the command line arguments (unused)
     * @throws Exception if any stage of the pipeline fails
     */
    public static void main(String[] args) throws Exception {
        // Paths are built with File so that the platform's separator is used.
        final String trainTxt = new File("weka", "pz_train.txt").getPath();
        final String trainArff = new File("weka", "pz_train.arff").getPath();
        final String reducedTxt = new File("weka", "pz_train_reduced.txt").getPath();
        final String reducedArff = new File("weka", "pz_train_reduced.arff").getPath();

        LayoutAnalysis la = new LayoutAnalysis();
        la.setup("scae.dat");
        la.getFeaturesAndLables(trainTxt);
        la.generateArff(trainTxt, trainArff);
        // Feature selection is given the absolute path of the ARFF file generated just above,
        // instead of a location that only exists on one particular machine.
        la.featureSelection(new File(trainArff).getAbsolutePath());
        la.deleteFeatures("scae_reduced.dat");
        //================
        la.setup("scae_reduced.dat");
        la.getFeaturesAndLables(reducedTxt);
        la.generateArff(reducedTxt, reducedArff);
        la.crossValidation(reducedArff);
    }

}