Java tutorial
/** * Copyright 2009 Kevin J. Menard Jr. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.github; import org.apache.log4j.Logger; import org.apache.commons.collections.CollectionUtils; import java.util.*; import java.util.concurrent.*; import java.io.IOException; public class NearestNeighbors { public final int TOP_COMMON_WATCHERS_COUNT = 10; public final int TOP_REPOS_COUNT = 100; public final int THREAD_POOL_SIZE = 5; private final Logger log = Logger.getLogger(NearestNeighbors.class); public final Map<String, Watcher> training_watchers; public final Map<String, Repository> training_repositories; public final Map<String, NeighborRegion> training_regions = new HashMap<String, NeighborRegion>(); public final Map<String, Set<NeighborRegion>> watchers_to_regions = new HashMap<String, Set<NeighborRegion>>(); public final Map<String, Set<Repository>> owners_to_repositories = new HashMap<String, Set<Repository>>(); public NearestNeighbors(final DataSet training_set) throws IOException { log.info("knn-init: Loading watchers and repositories."); training_watchers = training_set.getWatchers(); training_repositories = training_set.getRepositories(); log.info("knn-init: Building repository regions."); for (final Map.Entry<String, Repository> pair : training_repositories.entrySet()) { final Repository repo = pair.getValue(); final Repository root = Repository.findRoot(repo); final NeighborRegion existing_region = training_regions.get(root.id); if (existing_region == null) { training_regions.put(root.id, new NeighborRegion(repo)); } else { existing_region.add(repo); } // Store in inverted list structure from watcher ID to regions. for (final Watcher w : repo.watchers) { if (watchers_to_regions.get(w.id) == null) { watchers_to_regions.put(w.id, new HashSet<NeighborRegion>()); } watchers_to_regions.get(w.id).add(training_regions.get(root.id)); } // Store in inverted list structure from owenr to regions. if (owners_to_repositories.get(repo.owner) == null) { owners_to_repositories.put(repo.owner, new HashSet<Repository>()); } owners_to_repositories.get(repo.owner).add(repo); } } public Map<String, Map<String, Collection<Float>>> evaluate(final Collection<Watcher> test_instances) throws IOException, InterruptedException, ExecutionException { log.info("knn-evaluate: Loading watchers."); log.debug(String.format("knn-evaluate: Total unique test watchers: %d", test_instances.size())); final Map<String, Map<String, Collection<Float>>> results = new HashMap<String, Map<String, Collection<Float>>>(); final ExecutorService pool = Executors.newFixedThreadPool(THREAD_POOL_SIZE); // For each watcher in the test set . . . log.info("knn-evaluate: Starting evaluations"); int test_watcher_count = 0; for (final Watcher watcher : test_instances) { test_watcher_count++; log.info(String.format("Processing watcher (%d/%d)", test_watcher_count, test_instances.size())); results.put(watcher.id, new HashMap<String, Collection<Float>>()); // See if we have any training instances for the watcher. If not, we really can't guess anything. final Watcher training_watcher = training_watchers.get(watcher.id); if (training_watcher == null) { continue; } /*********************************** *** Handling repository regions *** ***********************************/ // Calculate the distance between the repository regions we know the test watcher is in, to every other // region in the training data. final Set<NeighborRegion> test_regions = watchers_to_regions.get(watcher.id); /* final List<NeighborRegion> related_regions = find_regions_with_most_cutpoints(watcher, test_regions); for (final NeighborRegion related_region : related_regions) { storeDistance(results, watcher, related_region.most_popular, 0.0f); storeDistance(results, watcher, related_region.most_forked, 0.0f); } */ /* also_owned_counts = {} training_watcher.repositories.each do |repo_id| repo = @training_repositories[repo_id] also_owned_counts[repo.owner] ||= 0 also_owned_counts[repo.owner] += 1 end also_owned_counts.each do |owner, count| # If 5% or more of the test watcher's repositories are owned by the same person, look at the owner's other repositories. if (also_owned_repos.size.to_f / training_watcher.repositories.size) > 0.05 || (also_owned_repos.size.to_f / @owners_to_repositories[owner].size) > 0.3 repositories_to_check.merge(@owners_to_repositories[owner].collect {|r| r.id}) end end */ // Add in the most forked regions from similar watchers. /* final Set<NeighborRegion> related_regions = find_regions_containing_fellow_watchers(test_regions); for (final NeighborRegion region : related_regions) { repositories_to_check.add(region.most_forked); } */ /************************************* **** Begin distance calculations **** *************************************/ int test_region_count = 0; for (final NeighborRegion test_region : test_regions) { test_region_count++; final CompletionService<Map<Repository, Float>> cs = new ExecutorCompletionService<Map<Repository, Float>>( pool); int training_region_count = 0; final Set<Repository> repositories_to_check = new HashSet<Repository>(); // Add in the most forked repositories from each region we know the test watcher is in. for (final NeighborRegion region : test_regions) { repositories_to_check.add(region.most_forked); } for (final Repository repo : training_watcher.repositories) { if (repo.parent != null) { repositories_to_check.add(repo.parent); } } /******************************************************************** *** Handling repositories owned by owners we're already watching *** ********************************************************************/ if (training_watcher.owner_counts.get(test_region.most_forked.owner) != null && (((training_watcher.owner_counts.get(test_region.most_forked.owner).floatValue() / owners_to_repositories.get(test_region.most_forked.owner).size()) > 0.25) || (training_watcher.owner_distribution(test_region.most_forked.owner) > 0.25))) { for (final Repository also_owned : owners_to_repositories.get(test_region.most_forked.owner)) { { // Only add repos that are the most forked in their respective regions. if (also_owned.region.most_forked.equals(also_owned)) { repositories_to_check.add(also_owned); } } } } for (final Repository training_repository : repositories_to_check) { training_region_count++; if (log.isDebugEnabled()) { log.debug(String.format("Processing watcher (%d/%d) - (%d/%d):(%d/%d)", test_watcher_count, test_instances.size(), test_region_count, test_regions.size(), training_region_count, repositories_to_check.size())); } // Submit distance calculation task if the test watcher isn't already watching the repository. cs.submit(new Callable<Map<Repository, Float>>() { public Map<Repository, Float> call() throws Exception { final Map<Repository, Float> ret = new HashMap<Repository, Float>(); if (!training_repository.watchers.contains(training_watcher)) { float distance = euclidian_distance(training_watcher, test_region.most_forked, training_repository); ret.put(training_repository, Float.valueOf(distance)); } return ret; } }); } // Process the distance calculation results. for (int i = 0; i < repositories_to_check.size(); i++) { final Map<Repository, Float> distance = cs.take().get(); for (final Map.Entry<Repository, Float> pair : distance.entrySet()) { storeDistance(results, watcher, pair.getKey(), pair.getValue().floatValue()); } } } } /* =begin # Find a set of repositories from fellow watchers that happen to watch a lot of same repositories as the test watcher. repositories_to_check.merge find_repositories_containing_fellow_watchers(test_regions) # Add in the most popular and most forked regions we know the test watcher is in. related_regions = find_regions_containing_fellow_watchers(test_regions) related_regions.each do |region| repositories_to_check << region.most_popular.id repositories_to_check << region.most_forked.id end $LOG.info "Added regions from fellow watchers for watcher #{watcher.id} -- new size #{repositories_to_check.size} (+ #{repositories_to_check.size - old_size})" old_size = repositories_to_check.size $LOG.info "Added similarly owned for watcher #{watcher.id} -- new size #{repositories_to_check.size} (+ #{repositories_to_check.size - old_size})" old_size = repositories_to_check.size =end =begin end results */ return results; } private Set<NeighborRegion> find_regions_containing_fellow_watchers(final Set<NeighborRegion> test_regions) { // Take a look at each region the test instance is in. // For each region, find the most common watchers. final Map<Watcher, Number> similar_watcher_counts = new HashMap<Watcher, Number>(); for (final NeighborRegion watched_region : test_regions) { for (final Watcher related_watcher : watched_region.watchers) { if (similar_watcher_counts.get(related_watcher) == null) { similar_watcher_counts.put(related_watcher, 0); } similar_watcher_counts.put(related_watcher, similar_watcher_counts.get(related_watcher).intValue() + 1); } } // Convert raw counts to ratios. for (final Map.Entry<Watcher, Number> pair : similar_watcher_counts.entrySet()) { similar_watcher_counts.put(pair.getKey(), pair.getValue().floatValue() / test_regions.size()); } // Collect the user IDs for the 10 most common watchers. final List<Map.Entry<Watcher, Number>> sorted = MyUtils.sortWatcherCounts(similar_watcher_counts, new NumberMeanComparator()); final List<Watcher> most_common_watchers = new ArrayList<Watcher>(); final int upperBound = sorted.size() < TOP_COMMON_WATCHERS_COUNT ? sorted.size() : TOP_COMMON_WATCHERS_COUNT; for (int i = 0; i < upperBound; i++) { most_common_watchers.add(sorted.get(i).getKey()); } // # Collect the user IDs for any user that appears in 50% or more of the watcher's repository regions. // #most_common_watchers = @similar_watcher_counts.find_all {|key, value| value >= TOP_COMMON_WATCHERS_PERCENT}.collect {|key, value| key} // Now go through each of those watchers and add in all the repository regions that they're watching, but // that the current watcher is not watching. final Set<NeighborRegion> ret = new HashSet<NeighborRegion>(); for (final Watcher common_watcher : most_common_watchers) { ret.addAll(watchers_to_regions.get(common_watcher.id)); if (ret.size() > TOP_REPOS_COUNT) { break; } } /* # Now sort the related regions by number of watchers and grab the 100 top ones. sorted_related_regions = related_regions.to_a.sort { |x, y| y.watchers.size <=> x.watchers.size } sorted_related_regions[0...TOP_COMMON_REPOS] */ return ret; } /** * Calculates Euclidian distance between two repositories. * * @param training_watcher * @param first * @param second * @return */ private float euclidian_distance(final Watcher training_watcher, final Repository first, final Repository second) { if (first.equals(second)) { return 1000.0f; } final Collection<Watcher> common_watchers = CollectionUtils.intersection(first.watchers, second.watchers); float distance = 1000.0f; // Set up weights. final float parent_child_weight = 0.4f; final float related_weight = 0.75f; final float same_owner_weight = 0.65f; final float common_watcher_weight = 1.0f; // Penalize repositories that have no parent nor children. final float lone_repo_weight = (first.parent == null && first.children.isEmpty()) || (second.parent == null && second.children.isEmpty()) ? 2.0f : 1.0f; final float leaf_repo_weight = second.children.isEmpty() ? 1.0f : 1.0f; /* // Figure out how many repositories the common watchers have in common with the test watcher. final Map<String, Number> similar_watcher_counts = new HashMap<String, Number>(); for (final NeighborRegion watched_region : watchers_to_regions.get(training_watcher.id)) { for (final Watcher w : watched_region.watchers) { if (similar_watcher_counts.get(w.id) == null) { similar_watcher_counts.put(w.id, new Integer(0)); } similar_watcher_counts.put(w.id, new Integer(similar_watcher_counts.get(w.id).intValue() + 1)); } } // Convert raw counts to ratios. for (final Map.Entry<String, Number> pair : similar_watcher_counts.entrySet()) { similar_watcher_counts.put(pair.getKey(), pair.getValue().floatValue() / watchers_to_regions.get(training_watcher.id).size()); } // Figure out how many repositories the test watcher watches that are owned by the first repository owner. float similarly_owned_count = 0.0f; for (final NeighborRegion watched_region : watchers_to_regions.get(training_watcher.id)) { if (watched_region.most_forked.owner == first.owner) { similarly_owned_count += 1.0f; } } // TODO: (KJM 08/29/09) Consider looking up the region, since that's what we use everywhere else. int total_watchers = 1; for (final Repository repo : owners_to_repositories.get(first.owner)) { total_watchers += repo.watchers.size(); } // If common_watchers is empty, just make the value 1.0 because we use the value as a divisor. float common_watchers_repo_diversity = common_watchers.isEmpty() ? 1.0f : 0.0f; for (final Watcher w : common_watchers) { common_watchers_repo_diversity += w.repositories.size(); } common_watchers_repo_diversity /= common_watchers.size(); */ final float first_common_watchers_ratio = ((float) common_watchers.size()) / first.watchers.size(); final float second_common_watchers_ratio = ((float) common_watchers.size()) / second.watchers.size(); if ((first.parent != null) && first.parent.equals(second)) { distance = 0.9f; //distance = (float) (parent_child_weight * (1.0 - (((float)common_watchers.size()) / MyUtils.mean(Arrays.asList(first.watchers.size(), second.watchers.size())))) // * (1.0 - ((similarly_owned_count + total_watchers) / Math.max(owners_to_repositories.get(first.owner).size(), 1))) + MyUtils.mean(similar_watcher_counts.values())); } else if (first.isRelated(second)) { distance = 0.5f; //distance = (float) (related_weight * (1.0 - (common_watchers.size() / MyUtils.mean(Arrays.asList(first.watchers.size(), second.watchers.size())))) // - ((common_watchers.size() / common_watchers_repo_diversity) / MyUtils.mean(Arrays.asList(first.watchers.size(), second.watchers.size())))); } else if (first.owner.equals(second.owner)) { distance = 0.4f * (1.0f - training_watcher.owner_distribution(second.owner)) / Math.max(1.0f, second.children.size()); //distance = (float) (same_owner_weight * (1.0 - (MyUtils.mean(Arrays.asList(((float)common_watchers.size()) / first.watchers.size(), ((float)common_watchers.size()) / second.watchers.size())))) // * (1.0 - ((similarly_owned_count + total_watchers) / owners_to_repositories.get(first.owner).size())) + MyUtils.mean(similar_watcher_counts.values())); } else { if (!common_watchers.isEmpty()) { distance = 0.7f; //final float first_common_watchers_ratio = ((float) common_watchers.size()) / first.watchers.size(); //final float second_common_watchers_ratio = ((float) common_watchers.size()) / second.watchers.size(); //distance = (float) (common_watcher_weight * (1.0 - (MyUtils.mean(Arrays.asList(first_common_watchers_ratio, second_common_watchers_ratio)))) // - (((float)common_watchers.size()) / common_watchers_repo_diversity) / MyUtils.mean(Arrays.asList(first.watchers.size(), second.watchers.size()))); } } int divisor = second.children.isEmpty() ? 1 : second.children.size(); return (leaf_repo_weight * lone_repo_weight * distance);// - MyUtils.mean(Arrays.asList(first_common_watchers_ratio, second_common_watchers_ratio)); /* common_watchers_repo_diversity = common_watchers.empty? ? 1 : common_watchers.collect {|w| training_watchers[w].repositories.size}.mean distance # Other factors for calculating distance: # - Ages of repositories # - Ancestry of two repositories (give higher weight if one of the repositories is the most popular by watchings and/or forks) # - # of forks # - watcher chains (e.g., repo a has watchers <2, 5>, repo b has watchers <5, 7>, repo c has watchers <7> . . . a & c may be slightly related. # - Language overlaps # - Size of repositories? # Also, look at weighting different attributes. Maybe use GA to optimize. */ } private void storeDistance(final Map<String, Map<String, Collection<Float>>> results, final Watcher watcher, final Repository repo, final Float distance) { if (results.get(watcher.id).get(repo.id) == null) { results.get(watcher.id).put(repo.id, new ArrayList<Float>()); } results.get(watcher.id).get(repo.id).add(distance); } private List<NeighborRegion> find_regions_with_most_cutpoints(final Watcher test_watcher, final Set<NeighborRegion> test_regions) { final Map<String, Collection<Integer>> related_region_counts = new HashMap<String, Collection<Integer>>(); // Look at each watcher in each of the test watcher's regions and find the other regions each of those watchers is in. for (final NeighborRegion watched_region : test_regions) { for (final Watcher related_watcher : watched_region.watchers) { for (final Repository related_repo : related_watcher.repositories) { final NeighborRegion related_region = related_repo.region; // Don't both adding in the region if we already know that it contains the test watcher. if (!related_region.watchers.contains(test_watcher)) { // Initialize counts list if necessary. if (related_region_counts.get(related_region.id) == null) { related_region_counts.put(related_region.id, new ArrayList<Integer>()); } // Add the cut point count. related_region_counts.get(related_region.id) .add(watched_region.cut_point_count(related_region)); } } } } final List<Map.Entry<String, Collection<Integer>>> sorted = MyUtils.sortMapByValues(related_region_counts, new IntegerMeanComparator()); final List<NeighborRegion> ret = new ArrayList<NeighborRegion>(); int upperBound = sorted.size() < TOP_REPOS_COUNT ? sorted.size() : TOP_REPOS_COUNT; for (int i = 0; i < upperBound; i++) { final String region_id = sorted.get(i).getKey(); ret.add(training_regions.get(region_id)); } return ret; } public NeighborRegion find_region(final Repository repo) { return training_regions.get(Repository.findRoot(repo).id); } /** * Chooses the k best predictions to make from all evaluated distances. * Evaluations is a hash of the form {watcher_id => {repo1_id => distance1, repo2_id => distance2}} * * @param knn * @param evaluations * @param k * @return */ public static Set<Watcher> predict(final NearestNeighbors knn, final Map<String, Map<String, Collection<Float>>> evaluations, final int k, final Map<String, Watcher> test_data) { final Set<Watcher> ret = new HashSet<Watcher>(); for (final Map.Entry<String, Map<String, Collection<Float>>> evaluation : evaluations.entrySet()) { final String user_id = evaluation.getKey(); final Watcher w = new Watcher(user_id); final Map<String, Collection<Float>> distances = evaluation.getValue(); if (!distances.isEmpty()) { final Watcher training_watcher = knn.training_watchers.get(user_id); final List<Map.Entry<String, Collection<Float>>> sorted = MyUtils.sortMapByValues(distances, new FloatMeanComparator()); int upperBound = distances.size() < k ? distances.size() : k; for (int i = 0; i < upperBound; i++) { // TODO (KJM 8/10/09) Only add repo if distance is below some threshold. final String repo_id = sorted.get(i).getKey(); final Repository repo = knn.training_repositories.get(repo_id); w.associate(repo); // Make sure the predicted watcher follows the same distribution as the training watcher. //final float normalize_factor = ((float) k) / training_watcher.repositories.size(); //if ((w.owner_counts.get(repo.owner).floatValue() / training_watcher.repositories.size()) / normalize_factor < training_watcher.owner_distribution(repo.owner)) //{ // w.repositories.remove(repo); // } } } ret.add(w); } return ret; } /** * Calculates accuracy between actual and predicted watchers. * * @param actual * @param predicted * @return */ public static float accuracy(final Watcher actual, final Watcher predicted) { if ((actual == null) || (predicted == null)) { return 0.0f; } if ((actual.repositories.isEmpty()) && (predicted.repositories.isEmpty())) { return 1.0f; } if ((actual.repositories.isEmpty()) && (!predicted.repositories.isEmpty())) { return 0.0f; } if ((actual.repositories.isEmpty()) || (predicted.repositories.isEmpty())) { return 0.0f; } int number_correct = CollectionUtils.intersection(actual.repositories, predicted.repositories).size(); int number_incorrect = CollectionUtils.subtract(predicted.repositories, actual.repositories).size(); // Rate the accuracy of the predictions, with a bias towards positive results. return ((float) number_correct) / actual.repositories.size(); // - ((float) (number_incorrect) / predicted.repositories.size(); } /** * Aggregates accuracies of evaluations of each item in the test set, yielding an overall accuracy score. * * @param test_set * @param predictions * @return */ public static float score(final DataSet test_set, final Set<Watcher> predictions) throws IOException { float number_correct = 0.0f; int total_repositories_to_predict = 0; // Look at each predicted answer for each watcher. If the prediction appears in the watcher's list, then it // was an accurate prediction. Otherwise, no score awarded. for (final Watcher prediction : predictions) { final Watcher actual = test_set.getWatchers().get(prediction.id); total_repositories_to_predict += actual.repositories.size(); for (final Repository r : prediction.repositories) { if (actual.repositories.contains(r)) { number_correct++; } } } return number_correct / total_repositories_to_predict; } private class IntegerMeanComparator implements Comparator<Map.Entry<String, Collection<Integer>>> { public int compare(final Map.Entry<String, Collection<Integer>> first, final Map.Entry<String, Collection<Integer>> second) { float firstAverage = MyUtils.mean(first.getValue()); float secondAverage = MyUtils.mean(second.getValue()); if (secondAverage > firstAverage) { return 1; } else if (secondAverage < firstAverage) { return -1; } return 0; } } private static class FloatMeanComparator implements Comparator<Map.Entry<String, Collection<Float>>> { public int compare(final Map.Entry<String, Collection<Float>> first, final Map.Entry<String, Collection<Float>> second) { float firstAverage = MyUtils.mean(first.getValue()); float secondAverage = MyUtils.mean(second.getValue()); if (firstAverage < secondAverage) { return -1; } else if (firstAverage > secondAverage) { return 1; } return 0; } } private static class NumberMeanComparator implements Comparator<Map.Entry<Watcher, Number>> { public int compare(final Map.Entry<Watcher, Number> first, final Map.Entry<Watcher, Number> second) { final float firstValue = first.getValue().floatValue(); final float secondValue = second.getValue().floatValue(); if (secondValue > firstValue) { return 1; } else if (secondValue < firstValue) { return -1; } return 0; } } }