lirmm.inria.fr.data.DataMatrix.java Source code

Java tutorial

Introduction

Here is the source code for lirmm.inria.fr.data.DataMatrix.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package lirmm.inria.fr.data;

import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.StringTokenizer;
import lirmm.inria.fr.math.BigSparseRealMatrix;
import lirmm.inria.fr.math.OpenLongToDoubleHashMap;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.NumberIsTooLargeException;

/**
 *
 * @author rbouadjenek
 */
public final class DataMatrix extends BigSparseRealMatrix {

    /**
     * Mapping for rows.
     */
    private final Map<String, Integer> rowsMapping;
    /**
     * Mapping for columns.
     */
    private final Map<String, Integer> columnsMapping;
    /**
     * Mapping for columns.
     */
    private final Map<Integer, Double> columnMeans;

    /**
     * Storage of the test dataset elements.
     */
    private final OpenLongToDoubleHashMap testSetEntries;

    /**
     * Storage of the validation dataset elements.
     */
    private final OpenLongToDoubleHashMap validationSetEntries;

    /**
     * Number of non zero entries in rows of the matrix.
     */
    private final OpenLongToDoubleHashMap rowNonZeroEntries = new OpenLongToDoubleHashMap(0.0);

    /**
     * Number of non zero entries in rows of the matrix.
     */
    private final OpenLongToDoubleHashMap columnNonZeroEntries = new OpenLongToDoubleHashMap(0.0);

    /**
     * Maximum value in the matrix
     */
    private double max;

    private DataMatrix(String file, int rowDimension, int columnDimension, double max)
            throws NotStrictlyPositiveException, NumberIsTooLargeException {
        //        super(5573, 100);
        super(rowDimension, columnDimension);
        this.rowsMapping = new HashMap<>();
        this.columnsMapping = new HashMap<>();
        this.columnMeans = new HashMap<>();
        this.testSetEntries = new OpenLongToDoubleHashMap(0.0);
        this.validationSetEntries = new OpenLongToDoubleHashMap(0.0);
        this.max = max;
        loadMatrix(file, rowDimension, columnDimension);
        //        normalize();
    }

    protected void loadMatrix(String file, int rowDimension, int columnDimension) {
        FileInputStream fstream;
        try {
            fstream = new FileInputStream(file);
            // Get the object of DataInputStream
            DataInputStream in = new DataInputStream(fstream);

            Set<Integer> cI = new HashSet();
            for (int i = 0; i < rowDimension; i++) {
                cI.add(i);
            }
            Set<Integer> cJ = new HashSet();
            for (int j = 0; j < columnDimension; j++) {
                cJ.add(j);
            }
            try (BufferedReader br = new BufferedReader(new InputStreamReader(in))) {
                String str;
                int z = 0;
                while ((str = br.readLine()) != null) {
                    str = str.trim();
                    if (str.startsWith("#")) {
                        continue;
                    }
                    if (str.trim().length() == 0) {
                        continue;
                    }
                    z++;
                    StringTokenizer st = new StringTokenizer(str);
                    String row_id = st.nextToken();
                    String column_id = st.nextToken();
                    double rating = Double.parseDouble(st.nextToken());
                    Integer i, j;
                    if (rowsMapping.containsKey(row_id)) {
                        i = rowsMapping.get(row_id);
                    } else {
                        i = cI.iterator().next();
                        //                        i = cI.get(0);
                        cI.remove(i);
                        rowsMapping.put(row_id, i);
                    }
                    if (columnsMapping.containsKey(column_id)) {
                        j = columnsMapping.get(column_id);
                    } else {
                        j = cJ.iterator().next();
                        //                        j = cJ.get(0);
                        cJ.remove(j);
                        columnsMapping.put(column_id, j);
                    }
                    setEntry(i, j, rating);
                }
            }

            for (OpenLongToDoubleHashMap.Iterator iterator = getEntries().iterator(); iterator.hasNext();) {
                iterator.advance();
                final long key = iterator.key();
                final int i, j;
                if (isTransposed()) {
                    j = (int) (key / getRowDimension());
                    i = (int) (key % getRowDimension());
                } else {
                    i = (int) (key / getColumnDimension());
                    j = (int) (key % getColumnDimension());
                }
                double v = rowNonZeroEntries.get(i);
                rowNonZeroEntries.put(i, v + 1);
                v = columnNonZeroEntries.get(j);
                columnNonZeroEntries.put(j, v + 1);
            }

        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    private void normalize() {
        int[] nbrElements;
        Map<Long, Double> entry = new HashMap();
        if (isTransposed()) {
            nbrElements = new int[getRowDimension()];
        } else {
            nbrElements = new int[getColumnDimension()];
        }
        for (OpenLongToDoubleHashMap.Iterator iterator = getEntries().iterator(); iterator.hasNext();) {
            iterator.advance();
            final double value = iterator.value();
            final long key = iterator.key();
            final int j;
            if (isTransposed()) {
                j = (int) (key / getRowDimension());
            } else {
                j = (int) (key % getColumnDimension());
            }
            nbrElements[j]++;
            if (columnMeans.containsKey(j)) {
                columnMeans.put(j, columnMeans.get(j) + value);
            } else {
                columnMeans.put(j, value);
            }
            entry.put(key, value);
        }
        for (Map.Entry<Integer, Double> e : columnMeans.entrySet()) {
            columnMeans.put(e.getKey(), e.getValue() / nbrElements[e.getKey()]);
        }
        for (Map.Entry<Long, Double> e : entry.entrySet()) {
            final long key = e.getKey();
            final double value = e.getValue();
            final int i, j;
            if (isTransposed()) {
                j = (int) (key / getRowDimension());
                i = (int) (key % getRowDimension());
            } else {
                i = (int) (key / getColumnDimension());
                j = (int) (key % getColumnDimension());
            }
            setEntry(i, j, value - columnMeans.get(j));
        }
    }

    public void cutDataSet(double rateTestSet, double rateValidationSet) {
        if (rateTestSet + rateValidationSet > 100) {
            System.err.println("Invalid rates:" + (rateTestSet + rateValidationSet) + " > 100%");
        } else {
            System.err.println("******************************");
            System.err.println("Initiale dataset size= " + getDataSize());
            System.err.println("******************************");
            int valTestSet = (int) (getDataSize() * rateTestSet / 100);// number of values to remove from the data entry, and put in the test set
            int valValidationSet = (int) (getDataSize() * rateValidationSet / 100);// number of values to remove from the data entry, and put in the validation set
            while (valTestSet > 0) {
                valTestSet--;
                long key = getEntries().getRandomKey();
                double value = getEntries().remove(key);
                testSetEntries.put(key, value);
                final int i, j;
                if (isTransposed()) {
                    j = (int) (key / getRowDimension());
                    i = (int) (key % getRowDimension());
                } else {
                    i = (int) (key / getColumnDimension());
                    j = (int) (key % getColumnDimension());
                }
                double v = rowNonZeroEntries.get(i);
                rowNonZeroEntries.put(i, v - 1);
                v = columnNonZeroEntries.get(j);
                columnNonZeroEntries.put(j, v - 1);

            }
            while (valValidationSet > 0) {
                valValidationSet--;
                long key = getEntries().getRandomKey();
                double value = getEntries().remove(key);
                validationSetEntries.put(key, value);
                final int i, j;
                if (isTransposed()) {
                    j = (int) (key / getRowDimension());
                    i = (int) (key % getRowDimension());
                } else {
                    i = (int) (key / getColumnDimension());
                    j = (int) (key % getColumnDimension());
                }
                double v = rowNonZeroEntries.get(i);
                rowNonZeroEntries.put(i, v - 1);
                v = columnNonZeroEntries.get(j);
                columnNonZeroEntries.put(j, v - 1);
            }
            System.err.println("Training dataset size= " + getDataSize());
            System.err.println("Test dataset size= " + getTestSetSize());
            System.err.println("Validation dataset size= " + getValidationSetSize());
            System.err.println("******************************");
        }
    }

    public static DataMatrix createDataMatrix(String file) throws FileNotFoundException, IOException {
        FileInputStream fstream;
        Set<String> listI = new HashSet<>();
        Set<String> listJ = new HashSet();
        fstream = new FileInputStream(file);
        double max = 0;
        // Get the object of DataInputStream
        DataInputStream in = new DataInputStream(fstream);
        BufferedReader br = new BufferedReader(new InputStreamReader(in));
        String str;
        while ((str = br.readLine()) != null) {
            str = str.trim();
            if (str.startsWith("#")) {
                continue;
            }
            if (str.trim().length() == 0) {
                continue;
            }
            StringTokenizer st = new StringTokenizer(str);
            String i = st.nextToken();
            String j = st.nextToken();
            double rating = Double.parseDouble(st.nextToken());
            if (max < rating) {
                max = rating;
            }
            if (!listI.contains(i)) {
                listI.add(i);
            }
            if (!listJ.contains(j)) {
                listJ.add(j);
            }
        }
        return new DataMatrix(file, listI.size(), listJ.size(), max);
    }

    public OpenLongToDoubleHashMap getTestSetEntries() {
        return testSetEntries;
    }

    public int getTestSetSize() {
        return testSetEntries.size();
    }

    public OpenLongToDoubleHashMap getValidationDataEntries() {
        return validationSetEntries;
    }

    public int getValidationSetSize() {
        return validationSetEntries.size();
    }

    public Map<String, Integer> getRowsMapping() {
        return rowsMapping;
    }

    public Map<String, Integer> getColumnsMapping() {
        return columnsMapping;
    }

    public double getEntry(String indexI, String indexJ) {
        int i = rowsMapping.get(indexI);
        int j = columnsMapping.get(indexJ);
        return getEntry(i, j);
    }

    public int getRowNonZeroEntry(int i) {
        if (isTransposed()) {
            return (int) columnNonZeroEntries.get(i);
        } else {
            return (int) rowNonZeroEntries.get(i);
        }
    }

    public int getColumnNonZeroEntry(int j) {
        if (isTransposed()) {
            return (int) rowNonZeroEntries.get(j);
        } else {
            return (int) columnNonZeroEntries.get(j);
        }
    }

    public double getMax() {
        return max;
    }

}