Java tutorial: classifying malware from resource-usage traces with libsvm
package malware_classification;

import libsvm.*;
import java.util.logging.*;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
import org.apache.commons.math3.analysis.interpolation.LinearInterpolator;
import org.apache.commons.math3.analysis.polynomials.PolynomialSplineFunction;

/**
 * @author bspar
 */
public class Malware_Classification {
    // [i][j] is timepoint i, attribute j
    // first element is 1 for malicious, 0 for benign
    private double[][] all_data;
    private final String timestamp_str = "Time [ms]";         // first table
    private final String timestamp_str_2 = "Timestamp [ms]";  // second table
    private final int benign_class = 0;
    private final int malicious_class = 1;
    private int num_files; // number of files that have been read
    private HashMap<Integer, String> index_to_app;
    private PrintWriter writer;
    private static final Logger logger =
            Logger.getLogger(Malware_Classification.class.getName());

    // first table column headings
    private final String[] valid_col_names = {
        "CPU1 Load [%]", "Memory Usage", "CPU2 Load [%]", "CPU3 Load [%]",
        "CPU4 Load [%]", "GPU Load [%]", "CPU Load [%]"
    };
    // second table column headings
    private final String[] valid_col_names_2 = {
        "CPU Load [%]", "Mobile Bytes Sent [bytes]", "Mobile Bytes Received [bytes]",
        "Other Bytes Sent [bytes]", "Other Bytes Received [bytes]"
    };

    // Observed accuracy: 75.9% with bin size 1000, 73% with 5000,
    // 72.5% with 10000.
    private int bin_size = 1000; // bin size in ms
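    // Row layout used throughout (established by coalesce_data/combine_data):
    //   [0]   class label (benign_class or malicious_class)
    //   [1]   per-file app id (set from num_files; used to group vectors by app)
    //   [2..] binned/interpolated feature values -- svm_train and svm_classify
    //         start reading features at index 2 for exactly this reason.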
    /**
     * @param args the command line arguments. Order is the path to the
     *             malicious data folder, the path to the benign data folder,
     *             and (optionally) bin_size.
     */
    public static void main(String[] args) {
        String malicious_file_path = args[0];
        String benign_file_path = args[1];
        int curr_bin_size;
        if (args.length > 2) {
            curr_bin_size = Integer.parseInt(args[2]);
        } else {
            curr_bin_size = -1;
        }

        String pid_str = ManagementFactory.getRuntimeMXBean().getName();
        logger.setLevel(Level.CONFIG);
        logger.log(Level.INFO, pid_str);

        // Log to the first std_output<i>.txt that does not exist yet.
        boolean found_file = false;
        String output_base = "std_output";
        File output_file = null;
        for (int i = 0; !found_file; i++) {
            output_file = new File(output_base + i + ".txt");
            found_file = !output_file.exists();
        }
        FileHandler fh = null;
        try {
            fh = new FileHandler(output_file.getAbsolutePath());
        } catch (IOException | SecurityException ex) {
            Logger.getLogger(Malware_Classification.class.getName())
                    .log(Level.SEVERE, null, ex);
        }
        logger.addHandler(fh);
        logger.info("Writing output in " + output_file.getAbsolutePath());

        Malware_Classification classifier = new Malware_Classification(
                malicious_file_path, benign_file_path, curr_bin_size);
        // classifier.run_tests();
    }

    public Malware_Classification(String malicious, String benign, int given_bin_size) {
        String malicious_path = malicious;
        String benign_path = benign;
        num_files = 0;
        if (given_bin_size > 0) {
            this.bin_size = given_bin_size;
        }
        index_to_app = new HashMap<>();

        ArrayList<double[]> malicious_data = read_data_all(malicious_path);
        ArrayList<double[]> benign_data = read_data_all(benign_path);
        logger.log(Level.INFO, "Number of malicious: " + malicious_data.size()
                + " benign: " + benign_data.size());
        logger.info("Bin size = " + this.bin_size);
        all_data = coalesce_data(malicious_data, benign_data);

        long start_time_nano = System.nanoTime();
        double[][] scaled_data = svm_rescale(all_data);
        // double result = cross_validation(scaled_data, 10);
        double result = -1;
        cross_validate_by_app(scaled_data);
        long end_time_nano = System.nanoTime();
        double elapsed_time_sec = (double) (end_time_nano - start_time_nano)
                / (1000 * 1000 * 1000);
        logger.log(Level.INFO, "Full result = " + result + " over " + all_data.length
                + " data points in " + elapsed_time_sec + " seconds");
    }
    /*
     * Performs cross validation by training on all data from all files except
     * one and trying to classify that one file (like LOOCV, but per file
     * rather than per vector). Assumes all_data is in the same format as that
     * given to cross_validation, so all_data[i][j] = vector i, feature j,
     * where j=0 is the class and j=1 is a unique number associated with each
     * app.
     */
    private void cross_validate_by_app(double[][] all_data) {
        // Map file numbers to all vectors from that file
        writer = make_print_writer();
        writer.println("App Name,Class,File Number,Proportion Correct");
        HashMap<Integer, HashSet<Integer>> file_to_indices = new HashMap<>();
        for (int i = 0; i < all_data.length; i++) {
            int curr_file_num = (int) (0.5 + all_data[i][1]); // round to nearest int
            if (!file_to_indices.containsKey(curr_file_num)) {
                file_to_indices.put(curr_file_num, new HashSet<Integer>());
            }
            file_to_indices.get(curr_file_num).add(i);
        }

        for (int file_num : file_to_indices.keySet()) {
            TestResults test_results = new TestResults();
            TestResults train_results = new TestResults();
            double[][] train_data =
                    get_elements_not_excluded(all_data, file_to_indices.get(file_num));
            double[][] test_data = get_elements(all_data, file_to_indices.get(file_num));
            logger.info("Beginning cross validation on file " + index_to_app.get(file_num));

            int true_class = (int) (test_data[0][0] + 0.5);
            svm_model model = svm_train(train_data);
            svm_evaluate(test_data, model, test_results);
            svm_evaluate(train_data, model, train_results);

            double test_prop_correct = -1;
            if (true_class == benign_class) {
                test_prop_correct =
                        1 - (double) test_results.wrong_benign / test_results.total_benign;
                if (test_results.total_malicious != 0
                        || test_results.total_benign != test_data.length) {
                    logger.warning("Test app data are not all of one class.");
                }
            } else {
                test_prop_correct =
                        1 - (double) test_results.wrong_malicious / test_results.total_malicious;
                if (test_results.total_benign != 0
                        || test_results.total_malicious != test_data.length) {
                    logger.warning("Test app data are not all of one class.");
                }
            }
            if (test_prop_correct < 0) {
                logger.warning("Bad proportion correct");
            }

            String classification = (true_class == benign_class) ? "benign" : "malicious";
            logger.info("Proportion correct = " + test_prop_correct
                    + " for class " + classification);
            writer.println("" + index_to_app.get(file_num) + "," + classification + ","
                    + file_num + "," + test_prop_correct);
        }
        writer.close();
    }
    /*
     * Returns a new PrintWriter to use to output all app data. Names the file
     * app_output[num].txt, where [num] is one higher than the previously
     * highest existing file: if app_output1.txt already exists, this returns
     * a PrintWriter to app_output2.txt.
     */
    private PrintWriter make_print_writer() {
        boolean found_file = false;
        String output_base = "app_output";
        File output_file = null;
        for (int i = 0; !found_file; i++) {
            output_file = new File(output_base + i + ".txt");
            found_file = !output_file.exists();
        }
        PrintWriter pw = null;
        try {
            pw = new PrintWriter(output_file.getAbsolutePath());
        } catch (IOException | SecurityException ex) {
            Logger.getLogger(Malware_Classification.class.getName())
                    .log(Level.SEVERE, null, ex);
        }
        logger.info("Writing output app data to " + output_file);
        return pw;
    }

    /* Returns all rows of data EXCEPT those at the indices specified in indices. */
    private double[][] get_elements_not_excluded(double[][] data, HashSet<Integer> indices) {
        double[][] new_data = new double[data.length - indices.size()][data[0].length];
        int new_data_index = 0;
        for (int i = 0; i < data.length; i++) {
            if (!indices.contains(i)) {
                System.arraycopy(data[i], 0, new_data[new_data_index], 0, data[i].length);
                new_data_index++;
            }
        }
        return new_data;
    }

    /* Returns only the rows of data at the indices specified in indices. */
    private double[][] get_elements(double[][] data, HashSet<Integer> indices) {
        double[][] new_data = new double[indices.size()][data[0].length];
        int new_data_index = 0;
        for (int i = 0; i < data.length; i++) {
            if (indices.contains(i)) {
                System.arraycopy(data[i], 0, new_data[new_data_index], 0, data[i].length);
                new_data_index++;
            }
        }
        return new_data;
    }

    /*
     * Performs k-way cross validation on all_data. Assumes all_data[i][j] is
     * timepoint i, attribute j; j=0 is the class (1 or 0).
     */
    private double cross_validation(double[][] all_data, int k) {
        ArrayList<Integer> all_indices = new ArrayList<>();
        for (int i = 0; i < all_data.length; i++) {
            all_indices.add(i);
        }
        Collections.shuffle(all_indices);
        int entries_per_trial = all_data.length / k;
        TestResults test_results = new TestResults();
        TestResults train_results = new TestResults();

        for (int i = 0; i < k; i++) {
            logger.log(Level.INFO, "Beginning cross validation " + i + " of " + k);
            int min_index = entries_per_trial * i;
            // The last fold absorbs any remainder.
            int max_index = (i == k - 1) ? all_data.length : entries_per_trial * (i + 1);
            double[][] test_data = get_indices_from_array_inclusive(all_data,
                    min_index, max_index, all_indices);
            double[][] train_data = get_indices_from_array_exclusive(all_data,
                    min_index, max_index, all_indices);
            logger.log(Level.INFO, "Beginning training");
            svm_model model = svm_train(train_data);
            logger.log(Level.INFO, "Beginning testing");
            svm_evaluate(test_data, model, test_results);
            svm_evaluate(train_data, model, train_results);
        }
        logger.info("Test results (proportion incorrect):\n" + test_results.toString());
        logger.info("Train results (proportion incorrect):\n" + train_results.toString());
        double full_result = 1
                - ((double) test_results.wrong_benign + test_results.wrong_malicious)
                / (test_results.total_benign + test_results.total_malicious);
        logger.info("Proportion CORRECT (both classes, testing) = " + full_result);
        return full_result;
    }
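    // Worked example of the fold arithmetic above: with 10 vectors and k = 3,
    // entries_per_trial = 3, so the test folds are the shuffled index ranges
    // [0,3), [3,6), and [6,10) -- the final fold picks up the remainder.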
    /*
     * Returns all rows of all_data EXCEPT rows min_index (inclusive) through
     * max_index (exclusive), reading rows through the shuffled index list.
     */
    private double[][] get_indices_from_array_exclusive(double[][] all_data, int min_index,
            int max_index, ArrayList<Integer> indices) {
        int arr_length = all_data.length - (max_index - min_index);
        double[][] result = new double[arr_length][all_data[0].length];
        for (int i = 0; i < min_index; i++) {
            for (int j = 0; j < all_data[i].length; j++) {
                result[i][j] = all_data[indices.get(i)][j];
            }
        }
        // Start at max_index (not max_index + 1): max_index itself is excluded
        // from the test fold, so it belongs in the training set.
        for (int i = max_index; i < all_data.length; i++) {
            for (int j = 0; j < all_data[i].length; j++) {
                result[i - (max_index - min_index)][j] = all_data[indices.get(i)][j];
            }
        }
        return result;
    }

    /*
     * Returns a new double[][] holding rows min_index (inclusive) through
     * max_index (exclusive), reading rows through the shuffled index list.
     */
    private double[][] get_indices_from_array_inclusive(double[][] all_data, int min_index,
            int max_index, ArrayList<Integer> indices) {
        double[][] result = new double[max_index - min_index][all_data[0].length];
        for (int i = min_index; i < max_index; i++) {
            for (int j = 0; j < all_data[i].length; j++) {
                result[i - min_index][j] = all_data[indices.get(i)][j];
            }
        }
        return result;
    }

    /*
     * Returns a single matrix with both malicious and benign data in it.
     * The first element of each row is 1 for malicious, 0 for benign.
     */
    private double[][] coalesce_data(ArrayList<double[]> malicious, ArrayList<double[]> benign) {
        boolean is_valid = validate_data(malicious, benign);
        if (!is_valid) {
            logger.log(Level.WARNING, "Data has inconsistent sizes");
        }
        int num_elems_total = malicious.size() + benign.size();
        int num_cols = malicious.get(0).length;
        double[][] all_data = new double[num_elems_total][num_cols + 1];
        int benign_size = benign.size();
        for (int i = 0; i < benign_size; i++) {
            double[] timepoint = benign.get(i);
            all_data[i][0] = benign_class;
            for (int j = 1; j <= timepoint.length; j++) {
                all_data[i][j] = timepoint[j - 1];
            }
        }
        for (int i = 0; i < malicious.size(); i++) {
            double[] timepoint = malicious.get(i);
            all_data[i + benign_size][0] = malicious_class;
            for (int j = 1; j <= timepoint.length; j++) {
                all_data[i + benign_size][j] = timepoint[j - 1];
            }
        }
        return all_data;
    }

    /*
     * Returns true if all timepoints in malicious and benign have the same
     * length, false otherwise.
     */
    private boolean validate_data(ArrayList<double[]> malicious, ArrayList<double[]> benign) {
        int correct_length = -1;
        if (malicious.size() > 0) {
            correct_length = malicious.get(0).length;
        } else if (benign.size() > 0) {
            // Was malicious.get(0).length, which would throw when malicious is empty.
            correct_length = benign.get(0).length;
        } else {
            return true;
        }
        for (int i = 0; i < malicious.size(); i++) {
            if (malicious.get(i).length != correct_length) {
                return false;
            }
        }
        for (int i = 0; i < benign.size(); i++) {
            if (benign.get(i).length != correct_length) {
                return false;
            }
        }
        return true;
    }

    /* Reads all data from all csv files in folder_name. */
    private ArrayList<double[]> read_data_all(String folder_name) {
        File folder = new File(folder_name);
        File[] file_list = folder.listFiles();
        ArrayList<double[]> all_data = new ArrayList<>();
        for (File curr_file : file_list) {
            ArrayList<double[]> curr_data = read_data_file(curr_file.getAbsolutePath());
            if (curr_data != null) {
                all_data.addAll(curr_data);
                this.num_files++;
            }
        }
        return all_data;
    }
    /*
     * Reads the data in the csv file called csvFile and returns it in a
     * matrix: data.get(i)[j] is timepoint i, dimension j.
     */
    public ArrayList<double[]> read_data_file(String csvFile) {
        BufferedReader br = null;
        String line = "";
        String csvSplitBy = ",";
        ArrayList<double[]> data = new ArrayList<>();
        ArrayList<double[]> timestamps = new ArrayList<>();
        ArrayList<double[]> binned_result = new ArrayList<>();
        ArrayList<double[]> binned_result_all = null;
        logger.log(Level.FINE, "Reading file {0}", csvFile);

        try {
            br = new BufferedReader(new FileReader(csvFile));
            line = br.readLine();
            line = br.readLine();
            if (line == null) {
                return data;
            }
            String app_name = line.split(csvSplitBy)[0];
            index_to_app.put(num_files, app_name);
            line = br.readLine();
            line = br.readLine();
            if (line == null) {
                return data;
            }
            String[] col_headings = line.split(csvSplitBy);
            int correct_num_col_headings = col_headings.length - 1; // -1 for "description"
            int[] valid_cols = cols_with_data(col_headings, this.valid_col_names);
            int[] col_timestamps = cols_with_timestamp(valid_cols, col_headings, timestamp_str);
            int num_cols = valid_cols.length;

            // Read in data from the first table
            while ((line = br.readLine()) != null) {
                // use comma as separator
                String[] timepoint = line.split(csvSplitBy);
                if (timepoint.length != correct_num_col_headings) {
                    break;
                }
                if (!is_valid_timepoint(timepoint, valid_cols)) {
                    continue;
                }
                double[] new_row = new double[num_cols];
                double[] new_timestamp = new double[num_cols];
                for (int i = 0; i < num_cols; i++) {
                    try {
                        // Parse as doubles (the second table already does);
                        // Integer.parseInt would reject fractional values
                        // such as "12.5".
                        new_row[i] = Double.parseDouble(timepoint[valid_cols[i]]);
                        new_timestamp[i] = Double.parseDouble(timepoint[col_timestamps[i]]);
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
                data.add(new_row);
                timestamps.add(new_timestamp);
            }
            binned_result = bin_data(data, timestamps);

            // Find the second table with the app's name
            ArrayList<double[]> second_row_data = new ArrayList<>();
            ArrayList<double[]> second_row_timestamp = new ArrayList<>();
            boolean has_found_table = false;
            int[] second_row_data_cols = new int[1];
            int[] second_row_timestamp_cols = new int[1];
            while ((line = br.readLine()) != null) {
                if (line.equals(app_name)) {
                    has_found_table = true;
                    line = br.readLine();
                    col_headings = line.split(csvSplitBy);
                    second_row_data_cols = cols_with_data(col_headings, valid_col_names_2);
                    second_row_timestamp_cols = cols_with_timestamp(second_row_data_cols,
                            col_headings, timestamp_str_2);
                    correct_num_col_headings = col_headings.length;
                    line = br.readLine();
                }
                if (has_found_table) {
                    double[] new_row_data = new double[second_row_data_cols.length];
                    double[] new_row_timestamp = new double[second_row_data_cols.length];
                    String[] curr_data = line.split(csvSplitBy);
                    // check for the end of the table
                    if (curr_data.length != correct_num_col_headings) {
                        break;
                    }
                    for (int i = 0; i < new_row_data.length; i++) {
                        if (second_row_data_cols[i] >= curr_data.length) {
                            logger.warning("Data column index is out of range for this row");
                        }
                        new_row_data[i] = Double.parseDouble(curr_data[second_row_data_cols[i]]);
                        new_row_timestamp[i] =
                                Double.parseDouble(curr_data[second_row_timestamp_cols[i]]);
                    }
                    second_row_data.add(new_row_data);
                    second_row_timestamp.add(new_row_timestamp);
                }
            }
            if (!has_found_table) {
                logger.warning("Could not find the second table for " + app_name);
                return null;
            }
            ArrayList<double[]> data_2_interp = interpolate_data(second_row_data,
                    second_row_timestamp, bin_size, binned_result.size());
            binned_result_all = combine_data(binned_result, data_2_interp);
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (br != null) {
                try {
                    br.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
        return binned_result_all;
    }
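    /*
     * The parser above implies the following CSV layout (inferred from the
     * reads it performs, so treat the exact row positions as an assumption):
     * line 1 is skipped; line 2 carries the app name in its first cell; line 3
     * is skipped; line 4 holds the first table's column headings; data rows
     * then follow until a row's column count changes. Later in the file, a
     * line consisting solely of the app name introduces the second table: its
     * heading row comes next, and data rows are again read until the column
     * count changes.
     */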
    /*
     * Combines data_1 and data_2 elementwise (i.e. concatenates each array),
     * so result.get(i) = data_1.get(i) concatenated with data_2.get(i).
     * Also adds num_files (the instance variable) as the first element of
     * each returned array.
     */
    private ArrayList<double[]> combine_data(ArrayList<double[]> data_1,
            ArrayList<double[]> data_2) {
        ArrayList<double[]> result = new ArrayList<>();
        if (data_1.size() != data_2.size()) {
            logger.warning("Size of data from table 1 and table 2 are not equal");
        }
        for (int i = 0; i < data_1.size(); i++) {
            double[] elem_1 = data_1.get(i);
            double[] elem_2 = data_2.get(i);
            double[] new_row = new double[elem_1.length + elem_2.length + 1];
            new_row[0] = num_files;
            System.arraycopy(elem_1, 0, new_row, 1, elem_1.length);
            System.arraycopy(elem_2, 0, new_row, elem_1.length + 1, elem_2.length);
            result.add(new_row);
        }
        return result;
    }

    /*
     * Returns the data in data_orig linearly interpolated onto the bin grid.
     * data_orig.get(i)[j] is timepoint i, column j; timestamp_orig.get(i)[j]
     * is the matching timestamp. bin_size is the size of each bin, in ms.
     */
    private ArrayList<double[]> interpolate_data(ArrayList<double[]> data_orig,
            ArrayList<double[]> timestamp_orig, int bin_size, int num_bins) {
        ArrayList<double[]> interp_data = new ArrayList<>();
        int num_points_orig = data_orig.size();
        // TODO: change some of the num_points_orig to num_bins
        double[] x = new double[num_points_orig];
        double[] y = new double[num_points_orig];
        int num_cols = data_orig.get(0).length;
        for (int i = 0; i < num_bins; i++) {
            interp_data.add(new double[num_cols]);
        }
        for (int col = 0; col < num_cols; col++) {
            // LinearInterpolator needs plain arrays, one per column.
            for (int j = 0; j < num_points_orig; j++) {
                x[j] = timestamp_orig.get(j)[col];
                y[j] = data_orig.get(j)[col];
            }
            LinearInterpolator lin_interp = new LinearInterpolator();
            PolynomialSplineFunction interp_func = lin_interp.interpolate(x, y);
            for (int j = 0; j < num_bins; j++) {
                double curr_bin = bin_size * j;
                double[] knots = interp_func.getKnots();
                if (interp_func.isValidPoint(curr_bin)) {
                    interp_data.get(j)[col] = interp_func.value(curr_bin);
                } else if (knots[0] > curr_bin) {
                    // bin starts before the first sample: clamp to the first value
                    interp_data.get(j)[col] = y[0];
                } else if (knots[knots.length - 1] < curr_bin) {
                    // bin starts after the last sample: clamp to the last value
                    interp_data.get(j)[col] = y[y.length - 1];
                } else {
                    logger.warning("Cannot interpolate at bin starting at " + curr_bin);
                }
            }
        }
        return interp_data;
    }

    /*
     * Returns a new ArrayList with the binned data. data.get(i)[j] is
     * timepoint i, column j; timestamp.get(i)[j] is the timestamp
     * corresponding to data.get(i)[j].
     */
    private ArrayList<double[]> bin_data(ArrayList<double[]> data,
            ArrayList<double[]> timestamp) {
        double max_timestamp = timestamp.get(timestamp.size() - 1)[0];
        int num_bins = (int) max_timestamp / bin_size + 1;
        ArrayList<double[]> binned_data = new ArrayList<>();
        for (int i = 0; i < num_bins; i++) {
            double[] new_data_row = new double[data.get(0).length];
            for (int col = 0; col < data.get(0).length; col++) {
                ArrayList<Integer> indices =
                        find_indices_in_bin(data, timestamp, bin_size, i, col);
                double col_average = 0;
                for (int index : indices) {
                    col_average += data.get(index)[col];
                }
                // TODO: interpolate when a bin contains no data points
                // (dividing by indices.size() == 0 currently yields NaN).
                col_average /= indices.size();
                new_data_row[col] = col_average;
            }
            binned_data.add(new_data_row);
        }
        return binned_data;
    }
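    // Worked example (assumed data) for bin_data and interpolate_data: with
    // bin_size = 1000 ms and samples at t = 0, 400, and 1600 ms, bin 0 averages
    // the samples at 0 and 400 ms, bin 1 contains only the 1600 ms sample, and
    // interpolate_data clamps any bin outside the sampled range to the nearest
    // endpoint value.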
    /*
     * Returns the indices of the data points that fall in bin bin_num, where
     * each bin is of size bin_size. col_num is the column to bin over;
     * bin_num starts at 0.
     */
    private ArrayList<Integer> find_indices_in_bin(ArrayList<double[]> data,
            ArrayList<double[]> data_timestamps, int bin_size, int bin_num, int col_num) {
        // Linear search over the array.
        // TODO: make this binary search
        ArrayList<Integer> indices = new ArrayList<>();
        int bin_min = bin_num * bin_size;
        int bin_max = (bin_num + 1) * bin_size;
        for (int i = 0; i < data_timestamps.size(); i++) {
            double curr_timestamp = data_timestamps.get(i)[col_num];
            if (curr_timestamp < bin_max && curr_timestamp >= bin_min) {
                indices.add(i);
            }
        }
        return indices;
    }

    /*
     * @return the timestamp columns corresponding to the column indices in
     * data_col_indices. Assumes that the timestamp column is the first column
     * whose heading contains timestamp_heading to the left of (lower index
     * than) the column heading of interest.
     */
    private int[] cols_with_timestamp(int[] data_col_indices, String[] col_headings,
            String timestamp_heading) {
        int[] timestamp_col_indices = new int[data_col_indices.length];
        boolean found_index = false;
        for (int data_index = 0; data_index < data_col_indices.length; data_index++) {
            found_index = false;
            for (int curr_index = data_col_indices[data_index]; curr_index >= 0; curr_index--) {
                if (col_headings[curr_index].contains(timestamp_heading)) {
                    timestamp_col_indices[data_index] = curr_index;
                    found_index = true;
                    break;
                }
            }
            if (!found_index) {
                // Index through data_col_indices to name the offending column
                // (col_headings[data_index] would name the wrong one).
                logger.warning("Cannot find index of timestamp for datapoint "
                        + col_headings[data_col_indices[data_index]]);
            }
        }
        return timestamp_col_indices;
    }

    /*
     * Returns true if every field of timepoint is non-empty, false otherwise
     * (only the columns listed in valid_cols are checked).
     */
    private boolean is_valid_timepoint(String[] timepoint, int[] valid_cols) {
        String empty_string = "";
        for (int i = 0; i < valid_cols.length; i++) {
            if (timepoint[valid_cols[i]].equals(empty_string)) {
                return false;
            }
        }
        return true;
    }

    /*
     * Returns an array of the indices of columns in all_col_names whose
     * headings contain an entry of valid_col_names.
     */
    private int[] cols_with_data(String[] all_col_names, String[] valid_col_names) {
        ArrayList<Integer> result = new ArrayList<>();
        int[] missing_names = new int[valid_col_names.length];
        for (int i = 0; i < all_col_names.length; i++) {
            for (int j = 0; j < valid_col_names.length; j++) {
                if (all_col_names[i].contains(valid_col_names[j])) {
                    missing_names[j] = 1;
                    result.add(i);
                    break;
                }
            }
        }
        for (int i = 0; i < missing_names.length; i++) {
            if (missing_names[i] == 0) {
                logger.log(Level.INFO, "Missing name " + valid_col_names[i]);
            }
        }
        int[] result_array = new int[result.size()];
        for (int i = 0; i < result.size(); i++) {
            result_array[i] = result.get(i);
        }
        return result_array;
    }

    /* Runs a few unit tests on the svm package. */
    public void run_tests() {
        logger.log(Level.INFO, "Beginning (simple, sanity-check) classification");
        int num_dim = 2;
        int num_test_examples = 100;
        int num_train_examples = 1000;
        double[][] train_data = generate_data(num_train_examples, num_dim);
        double[][] test_data = generate_data(num_test_examples, num_dim);
        svm_model model = svm_train(train_data);
        double result = 0;
        for (int i = 0; i < num_test_examples; i++) {
            result += svm_classify(test_data[i], model);
        }
        result /= num_test_examples;
        logger.log(Level.INFO, "Result = " + result);
    }
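    // Note: generate_data below produces linearly separable vectors (every
    // feature shares the sign of 2*class - 1), so the linear-kernel sanity
    // check above should score close to 1.0. Because svm_train treats column 1
    // as the app id and starts reading features at column 2, only the second
    // generated feature actually reaches the SVM here.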
    /*
     * Returns a new num_points x (num_dim + 1) array (each data point is a
     * vector of length num_dim). The first element of each row is the class
     * (1 or 0).
     */
    private double[][] generate_data(int num_points, int num_dim) {
        double[][] data = new double[num_points][num_dim + 1];
        for (int curr_point = 0; curr_point < num_points; curr_point++) {
            int curr_class = (int) Math.round(Math.random()); // 0 or 1
            data[curr_point][0] = curr_class;
            int mult_factor = curr_class * 2 - 1; // -1 or 1
            for (int curr_dim = 1; curr_dim < num_dim + 1; curr_dim++) {
                data[curr_point][curr_dim] = Math.random() * mult_factor;
            }
        }
        return data;
    }

    private double[][] generate_data_circle(int num_points, int num_dim) {
        double[][] data = new double[num_points][num_dim + 1];
        // Class 0 is within the hypersphere of radius 0.5 centered at the origin.
        for (int curr_point = 0; curr_point < num_points; curr_point++) {
            double radius = 0;
            for (int curr_dim = 1; curr_dim < num_dim + 1; curr_dim++) {
                double new_point = Math.random() * 2 - 1;
                radius += new_point * new_point;
                data[curr_point][curr_dim] = new_point;
            }
            radius = Math.sqrt(radius);
            data[curr_point][0] = (radius < 0.5) ? 0 : 1;
        }
        return data;
    }

    private svm_model svm_train(double[][] train) {
        svm_problem prob = new svm_problem();
        int dataCount = train.length;
        prob.y = new double[dataCount];
        prob.l = dataCount;
        prob.x = new svm_node[dataCount][];

        for (int i = 0; i < dataCount; i++) {
            double[] features = train[i];
            // Column 0 is the class and column 1 the app id, so the SVM sees
            // features starting at index 2.
            prob.x[i] = new svm_node[features.length - 2];
            for (int j = 2; j < features.length; j++) {
                svm_node node = new svm_node();
                node.index = j;
                node.value = features[j];
                prob.x[i][j - 2] = node;
            }
            prob.y[i] = features[0];
        }

        svm_parameter param = new svm_parameter();
        param.probability = 1;
        param.gamma = 0.5;
        param.nu = 0.5;
        param.C = 1;
        param.svm_type = svm_parameter.C_SVC;
        param.kernel_type = svm_parameter.LINEAR;
        param.cache_size = 20000;
        param.eps = 0.001;

        svm_model model = svm.svm_train(prob, param);
        return model;
    }

    /*
     * Classifies each row of test_data with model and accumulates per-class
     * totals and error counts into results. Assumes test_data[i][0] is 0 for
     * benign and 1 for malicious for all i. Only adds to results; does NOT
     * zero it out first.
     */
    public void svm_evaluate(double[][] test_data, svm_model model, TestResults results) {
        for (double[] curr_vector : test_data) {
            int is_correct = svm_classify(curr_vector, model);
            if (curr_vector[0] == 0) {
                results.total_benign++;
                if (is_correct == 0) {
                    results.wrong_benign++;
                }
            } else {
                results.total_malicious++;
                if (is_correct == 0) {
                    results.wrong_malicious++;
                }
            }
        }
    }

    // Returns 1 if classified correctly, 0 otherwise
    private int svm_classify(double[] features, svm_model model) {
        // features[0] is the class and features[1] the app id, so there are
        // features.length - 2 actual SVM features (not features.length - 1,
        // which would leave a trailing null node and crash libsvm).
        svm_node[] nodes = new svm_node[features.length - 2];
        for (int i = 2; i < features.length; i++) {
            svm_node node = new svm_node();
            node.index = i;
            node.value = features[i];
            nodes[i - 2] = node;
        }
        int totalClasses = 2;
        int[] labels = new int[totalClasses];
        svm.svm_get_labels(model, labels);
        double[] prob_estimates = new double[totalClasses];
        double v = svm.svm_predict_probability(model, nodes, prob_estimates);
        return (Math.round(v) == Math.round(features[0])) ? 1 : 0;
    }
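    // libsvm note: setting param.probability = 1 in svm_train above is what
    // makes svm.svm_predict_probability usable here; on a model trained
    // without it, one would call svm.svm_predict instead.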
    /*
     * Returns data_orig rescaled so that every feature is in [0, 1]. Does not
     * touch the first two elements of each row (the class in column 0 and the
     * app id in column 1). Assumes data_orig[i][j] is vector i, feature j.
     */
    private double[][] svm_rescale(double[][] data_orig) {
        int num_features = data_orig[0].length;
        double[][] data_scaled = new double[data_orig.length][num_features];
        double[] max_vals = new double[num_features];
        double[] min_vals = new double[num_features];

        // find max and min vals of all features. Feature 0 is the true class.
        for (int curr_feature = 1; curr_feature < num_features; curr_feature++) {
            double max_val = Double.NEGATIVE_INFINITY;
            double min_val = Double.POSITIVE_INFINITY;
            for (int curr_vec = 0; curr_vec < data_orig.length; curr_vec++) {
                double curr_val = data_orig[curr_vec][curr_feature];
                if (curr_val > max_val) {
                    max_val = curr_val;
                }
                if (curr_val < min_val) {
                    min_val = curr_val;
                }
            }
            max_vals[curr_feature] = max_val;
            min_vals[curr_feature] = min_val;
            if (Double.compare(max_vals[curr_feature], min_vals[curr_feature]) == 0) {
                logger.info("All data for feature " + curr_feature + " are the same");
            }
        }

        // rescale: (x - min) / (max - min), leaving constant features at 0
        for (int curr_vec = 0; curr_vec < data_orig.length; curr_vec++) {
            data_scaled[curr_vec][0] = data_orig[curr_vec][0];
            data_scaled[curr_vec][1] = data_orig[curr_vec][1];
            for (int curr_feature = 2; curr_feature < num_features; curr_feature++) {
                if (Double.compare(max_vals[curr_feature], min_vals[curr_feature]) != 0) {
                    data_scaled[curr_vec][curr_feature] =
                            (data_orig[curr_vec][curr_feature] - min_vals[curr_feature])
                            / (max_vals[curr_feature] - min_vals[curr_feature]);
                } else {
                    data_scaled[curr_vec][curr_feature] = 0;
                }
            }
        }
        return data_scaled;
    }

    private class TestResults {
        public int total_benign, total_malicious;
        public int wrong_benign, wrong_malicious;

        public TestResults() {
            total_benign = 0;
            total_malicious = 0;
            wrong_benign = 0;
            wrong_malicious = 0;
        }

        /*
         * Returns a string reporting the proportion of incorrect results for
         * each class.
         */
        @Override
        public String toString() {
            double benign_wrong = (double) wrong_benign / total_benign;
            double malicious_wrong = (double) wrong_malicious / total_malicious;
            String result = "benign: " + benign_wrong + " = " + wrong_benign
                    + " / " + total_benign
                    + "\nmalicious: " + malicious_wrong + " = " + wrong_malicious
                    + " / " + total_malicious;
            return result;
        }
    }
}
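A minimal sketch of invoking the classifier programmatically rather than through its own main (the directory paths and bin size here are hypothetical, and both libsvm and Apache commons-math3 must be on the classpath):

package malware_classification;

public class RunClassifier {
    public static void main(String[] args) {
        // Hypothetical directories of CSV traces, one file per app;
        // pass -1 as the bin size to keep the 1000 ms default.
        // The constructor runs the whole pipeline: read, bin, rescale,
        // then leave-one-app-out cross validation.
        new Malware_Classification("data/malicious", "data/benign", 1000);
    }
}

Equivalently, from the command line: java malware_classification.Malware_Classification data/malicious data/benign 1000. Per-app results land in app_output<i>.txt and the log in std_output<i>.txt.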