// Java tutorial
/*
 * Copyright 2011 Research Studios Austria Forschungsgesellschaft mBH
 *
 * This file is part of easyrec.
 *
 * easyrec is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * easyrec is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with easyrec. If not, see <http://www.gnu.org/licenses/>.
 */
package org.easyrec.plugin.pearson.impl;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.easyrec.model.core.ItemVO;
import org.easyrec.model.core.RatingVO;
import org.easyrec.model.core.TenantVO;
import org.easyrec.plugin.pearson.PearsonService;
import org.easyrec.plugin.pearson.model.Settings;
import org.easyrec.plugin.pearson.model.User;
import org.easyrec.plugin.pearson.model.UserAssoc;
import org.easyrec.plugin.pearson.model.Weight;
import org.easyrec.plugin.pearson.store.dao.LatestActionDAO;
import org.easyrec.plugin.pearson.store.dao.LatestActionDAO.RatedTogether;
import org.easyrec.plugin.pearson.store.dao.UserAssocDAO;
import org.easyrec.plugin.pearson.store.dao.UserDAO;
import org.easyrec.plugin.pearson.store.dao.WeightDAO;
import org.easyrec.service.core.TenantService;
import org.easyrec.service.domain.TypeMappingService;

import java.util.*;

/**
 * {@link PearsonService} implementation that derives user-item rating predictions from
 * user-user correlation weights.
 *
 * <p>For each tenant the service loads the users and their average ratings, then writes one
 * {@link UserAssoc} per predicted (user, item) pair. The weights consumed by the prediction
 * step are read through the {@link WeightDAO}; the method that would compute them
 * ({@code calculateWeights}) is present but is currently not invoked.
 */
public class PearsonServiceImpl implements PearsonService {
    private final LatestActionDAO latestActionDao;
    private final Settings settings;
    private final TenantService tenantService;
    private final TypeMappingService typeMappingService;
    private final UserAssocDAO userAssocDao;
    private final UserDAO userDao;
    private final WeightDAO weightDao;

    protected final Log logger = LogFactory.getLog(getClass());

    /**
     * Creates the service with all of its collaborators.
     *
     * @param settings           plugin settings (action/item/source type names, case amplification,
     *                           inverse user frequency and normalization switches)
     * @param weightDao          store for user-user weights
     * @param userAssocDao       store for the resulting user-item predictions
     * @param userDao            source of the users of a tenant
     * @param latestActionDao    source of ratings and rating aggregates
     * @param tenantService      tenant lookup
     * @param typeMappingService maps type names from the settings to per-tenant ids
     */
    public PearsonServiceImpl(final Settings settings, final WeightDAO weightDao, final UserAssocDAO userAssocDao,
                              final UserDAO userDao, final LatestActionDAO latestActionDao,
                              final TenantService tenantService, final TypeMappingService typeMappingService) {
        this.settings = settings;
        this.weightDao = weightDao;
        this.userAssocDao = userAssocDao;
        this.userDao = userDao;
        this.latestActionDao = latestActionDao;
        this.tenantService = tenantService;
        this.typeMappingService = typeMappingService;
    }

    /**
     * Runs the plugin for one tenant, or for every tenant when no id is given.
     *
     * @param tenantId id of the tenant to process; {@code null} processes all tenants returned by
     *                 the tenant service
     * @throws IllegalArgumentException if a tenant to process is {@code null} (e.g. the tenant
     *                                  service returned {@code null} for {@code tenantId})
     */
    public void perform(final Integer tenantId) {
        final List<TenantVO> tenants;

        if (tenantId != null) {
            tenants = new ArrayList<TenantVO>(1);
            tenants.add(tenantService.getTenantById(tenantId));
        } else {
            tenants = tenantService.getAllTenants();
        }

        for (final TenantVO tenant : tenants) {
            performForTenant(tenant);
        }
    }

    /**
     * Computes a correlation weight for every unordered pair of users that has at least one
     * commonly rated item and stores it symmetrically through the weight DAO.
     *
     * <p>The weight is
     * {@code (n*Sum(xy) - Sum(x)*Sum(y)) / sqrt((n*Sum(x^2) - Sum(x)^2) * (n*Sum(y^2) - Sum(y)^2))}
     * over the commonly rated items, where every term is scaled by the pair's frequency factor.
     * Pairs whose weight is NaN or infinite (e.g. zero variance) are logged and skipped.
     *
     * @param tenantId       tenant whose ratings are read
     * @param actionTypeId   action type id passed to the rated-together query
     * @param itemTypeId     item type id passed to the rated-together query
     * @param users          users to correlate pairwise
     * @param averageRatings average rating per user id; not needed by the current formula and
     *                       therefore not read
     */
    private void calculateWeights(final Integer tenantId, final Integer actionTypeId, final Integer itemTypeId,
                                  final List<User> users, final Map<Integer, Double> averageRatings) {
        final int userCount = users.size();
        final int perc25 = (int) (userCount * 0.25);
        final int perc50 = (int) (userCount * 0.5);
        final int perc75 = (int) (userCount * 0.75);

        for (int i = 0; i < userCount; i++) {
            final User activeUser = users.get(i);

            if (logger.isInfoEnabled()) {
                if (i == perc25) {
                    logger.info("Weight calculation at 25%");
                }
                if (i == perc50) {
                    logger.info("Weight calculation at 50%");
                }
                if (i == perc75) {
                    logger.info("Weight calculation at 75%");
                }
                if (i % 10 == 0) {
                    logger.info(String.format("Weight calculation at user %d of %d", i, userCount));
                }
            }

            // Only j > i: the weight is stored symmetrically, so each pair is visited once.
            for (int j = i + 1; j < userCount; j++) {
                final User otherUser = users.get(j);

                final List<RatedTogether<Integer, Integer>> ratedTogether = latestActionDao
                        .getItemsRatedTogetherByUsers(tenantId, itemTypeId, activeUser.getUser(),
                                otherUser.getUser(), actionTypeId);

                // users don't have common rated items
                if (ratedTogether.isEmpty()) {
                    continue;
                }

                double frequency = 1.0;

                if (settings.isUseInverseUserFrequency()) {
                    // Floating-point division: int / int would truncate the ratio before the
                    // logarithm and wrongly yield 0 for every ratio in [1, 2).
                    frequency = Math.log10((double) userCount / ratedTogether.size());

                    if (frequency == 0.0) {
                        continue;
                    }
                }

                double frequencySum = 0.0;
                double expectedBoth = 0.0;
                double expectedActive = 0.0;
                double expectedOther = 0.0;
                double expectedActiveSquare = 0.0;
                double expectedOtherSquare = 0.0;

                for (final RatedTogether<Integer, Integer> rating : ratedTogether) {
                    final double ratingActive = rating.getRating1().getRatingValue();
                    final double ratingOther = rating.getRating2().getRatingValue();

                    frequencySum += frequency;
                    expectedBoth += frequency * ratingActive * ratingOther;
                    expectedActive += frequency * ratingActive;
                    expectedOther += frequency * ratingOther;
                    expectedActiveSquare += frequency * Math.pow(ratingActive, 2.0);
                    expectedOtherSquare += frequency * Math.pow(ratingOther, 2.0);
                }

                // TODO replace EX^2 - (EX)^2 with E((X-EX)^2) for better stability
                final double varianceActive = frequencySum * expectedActiveSquare - Math.pow(expectedActive, 2.0);
                final double varianceOther = frequencySum * expectedOtherSquare - Math.pow(expectedOther, 2.0);

                double numerator1 = frequencySum * expectedBoth;
                double numerator2 = expectedActive * expectedOther;
                final double denominator = Math.sqrt(varianceActive * varianceOther);

                // Divide each term separately before subtracting (kept from the original).
                numerator1 /= denominator;
                numerator2 /= denominator;

                final double weight = numerator1 - numerator2;

                if (Double.isNaN(weight) || Double.isInfinite(weight)) {
                    if (logger.isWarnEnabled()) {
                        logger.warn(String.format(
                                "Weight is %s for users %d and %d (vAct=%.2f, vOth=%.2f, Eact2=%.2f, Eoth2=%.2f, " +
                                        "Ebot=%.2f, Eact=%.2f, Eoth=%.2f, fre=%.2f fsum=%.2f, num1=%.2f, " +
                                        "numer2=%.2f, den=%.2f)",
                                Double.isNaN(weight) ? "NaN" : "Inf", i, j, varianceActive, varianceOther,
                                expectedActiveSquare, expectedOtherSquare, expectedBoth, expectedActive,
                                expectedOther, frequency, frequencySum, numerator1, numerator2, denominator));
                    }

                    continue;
                }

                final Weight weightObj = new Weight(activeUser, otherUser, weight);
                weightDao.insertOrUpdateWeightSymmetric(weightObj);
            }
        }
    }

    /**
     * Loads the average ratings reported by the DAO and indexes them by user id.
     *
     * @param tenantId   tenant whose ratings are aggregated
     * @param itemTypeId item type id passed to the DAO query
     * @return map from user id to that user's average rating value
     */
    private Map<Integer, Double> getAverageUserRatingMap(final Integer tenantId, final Integer itemTypeId) {
        final List<RatingVO<Integer, Integer>> averageRatings =
                latestActionDao.getAverageRatingsForUser(tenantId, itemTypeId);
        final Map<Integer, Double> result = new HashMap<Integer, Double>(averageRatings.size());

        for (final RatingVO<Integer, Integer> averageRating : averageRatings) {
            result.put(averageRating.getUser(), averageRating.getRatingValue());
        }

        return result;
    }

    /**
     * Resolves the configured type ids for the tenant and runs the prediction step, logging
     * how long each phase took.
     *
     * @param tenant tenant to process
     * @throws IllegalArgumentException if {@code tenant} is {@code null}
     */
    private void performForTenant(final TenantVO tenant) {
        if (tenant == null) {
            throw new IllegalArgumentException("tenant is null");
        }

        final Integer tenantId = tenant.getId();
        final Integer actionTypeId = typeMappingService.getIdOfActionType(tenantId, settings.getActionType());
        final Integer itemTypeId = typeMappingService.getIdOfItemType(tenantId, settings.getItemType());
        final Integer sourceTypeId = typeMappingService.getIdOfSourceType(tenantId, settings.getSourceType());
        final Integer minRatingValue = tenant.getRatingRangeMin();
        final Integer maxRatingValue = tenant.getRatingRangeMax();
        final Date changeDate = new Date();

        final List<User> users = userDao.getUsersForTenant(tenantId);
        final Map<Integer, Double> averageRatings = getAverageUserRatingMap(tenantId, itemTypeId);

        logger.info("Starting weight calculation.");
        final long weightsStart = System.currentTimeMillis();
        // NOTE(review): the weight calculation is disabled here, so predictions rely on whatever
        // weights are already stored - confirm whether this is intentional before re-enabling.
        // calculateWeights(tenantId, actionTypeId, itemTypeId, users, averageRatings);
        logger.info(String.format("Calculating weights for %s took %.2f seconds", tenant.getStringId(),
                secondsSince(weightsStart)));

        logger.info("Starting predictions.");
        final long predictionsStart = System.currentTimeMillis();
        predict(tenantId, actionTypeId, itemTypeId, sourceTypeId, changeDate, users, averageRatings,
                minRatingValue, maxRatingValue);
        logger.info(String.format("Calculating USER-ITEM predictions for %s took %.2f seconds",
                tenant.getStringId(), secondsSince(predictionsStart)));
    }

    /**
     * Returns the wall-clock time elapsed since {@code startMillis} in fractional seconds.
     * Divides by a double so sub-second durations are not truncated to whole seconds.
     */
    private static double secondsSince(final long startMillis) {
        return (System.currentTimeMillis() - startMillis) / 1000.0;
    }

    /**
     * Applies case amplification to a weight: raises its magnitude to {@code exponent} while
     * keeping its sign.
     */
    private static double amplify(final double weight, final double exponent) {
        return weight >= 0 ? Math.pow(weight, exponent) : -Math.pow(-weight, exponent);
    }

    /**
     * Predicts a rating for every candidate (user, item) pair and stores it as a
     * {@link UserAssoc}.
     *
     * <p>The prediction is the active user's average rating plus the weighted mean of the other
     * users' deviations from their own averages, where the normalizer (kappa) is the sum of the
     * absolute weights. Candidate items are the items not yet rated by the user, or - when a
     * test data source type is configured - the items associated to the user under that source
     * type.
     *
     * @param tenantId       tenant to predict for
     * @param actionTypeId   action type id used to check whether the user already rated an item
     * @param itemTypeId     item type id of the candidate items
     * @param sourceTypeId   source type id stored with each prediction
     * @param changeDate     change date stored with each prediction
     * @param users          users to predict for
     * @param averageRatings average rating per user id
     * @param minRatingValue lower clamp bound, applied only when prediction normalization is on
     * @param maxRatingValue upper clamp bound, applied only when prediction normalization is on
     */
    private void predict(final Integer tenantId, final Integer actionTypeId, final Integer itemTypeId,
                         final Integer sourceTypeId, final Date changeDate, final List<User> users,
                         final Map<Integer, Double> averageRatings, final Integer minRatingValue,
                         final Integer maxRatingValue) {
        // Check for null BEFORE unboxing: an unset case amplification means "disabled" and must
        // not fail with a NullPointerException.
        final boolean useCaseAmplification = settings.getCaseAmplification() != null;
        final double caseAmplification = useCaseAmplification ? settings.getCaseAmplification() : 1.0;

        int cur = 0;
        final int perc25 = (int) (users.size() * 0.25);
        final int perc50 = (int) (users.size() * 0.5);
        final int perc75 = (int) (users.size() * 0.75);

        for (final User activeUser : users) {
            if (logger.isInfoEnabled()) {
                if (cur == perc25) {
                    logger.info("Predictions at 25%");
                }
                if (cur == perc50) {
                    logger.info("Predictions at 50%");
                }
                if (cur == perc75) {
                    logger.info("Predictions at 75%");
                }
                logger.info(String.format("Predictions at user %d of %d", cur, users.size()));
            }

            cur++;

            final List<ItemVO<Integer, Integer>> items;

            if (settings.getTestDataSourceType() == null) {
                items = latestActionDao.getItemsNotRatedByUser(tenantId, activeUser.getUser(), itemTypeId);
            } else {
                // we are only predicting for the test-set
                items = userAssocDao.getItemsAssociatedToUser(tenantId, activeUser.getUser(), itemTypeId,
                        settings.getTestDataSourceType());
            }

            for (final ItemVO<Integer, Integer> item : items) {
                if (latestActionDao.didUserRateItem(activeUser.getUser(), item, actionTypeId)) {
                    continue;
                }

                double kappa = 0.0;
                double weightedRatings = 0.0;

                final List<Weight> weights = weightDao.getWeightsForUser1AndItem(tenantId, activeUser.getUser(),
                        item.getItem(), item.getType());

                if (weights.isEmpty()) {
                    if (logger.isInfoEnabled()) {
                        logger.info(String.format(
                                "Couldn't calculate prediction for user %d and item %d because no weights were present",
                                activeUser.getUser(), item.getItem()));
                    }

                    continue;
                }

                // TODO get users that rated item
                // TODO get weights for users that rated item

                for (final Weight weight : weights) {
                    final User otherUser = weight.getUser2();

                    final List<RatingVO<Integer, Integer>> ratingsOther = latestActionDao.getLatestRatingsForTenant(
                            tenantId, itemTypeId, item.getItem(), otherUser.getUser(), null);

                    if (ratingsOther.size() > 1 && logger.isWarnEnabled()) {
                        logger.warn(" There shouldn't be more than 1 rating");
                    }

                    if (ratingsOther.size() != 1) {
                        // the other user didn't rate the current item
                        continue;
                    }

                    final double ratingOther = ratingsOther.get(0).getRatingValue();

                    double currentWeight = weight.getWeight();

                    if (useCaseAmplification) {
                        currentWeight = amplify(currentWeight, caseAmplification);
                    }

                    final double averageRatingOther = averageRatings.get(otherUser.getUser());

                    kappa += Math.abs(currentWeight);
                    weightedRatings += currentWeight * (ratingOther - averageRatingOther);
                }

                if (kappa == 0.0 || Double.isNaN(kappa) || Double.isInfinite(kappa)) {
                    if (logger.isInfoEnabled()) {
                        logger.info(String.format(" Prediction for user %d item %d failed (kappa=%f)",
                                activeUser.getUser(), item.getItem(), kappa));
                    }

                    continue;
                }

                double prediction = weightedRatings / kappa;

                if (Double.isNaN(weightedRatings) || Double.isInfinite(weightedRatings)) {
                    if (logger.isInfoEnabled()) {
                        logger.info(String.format(
                                " Prediction for user %d item %d failed (prediction=%f, weightedRatings=%f, kappa=%f)",
                                activeUser.getUser(), item.getItem(), prediction, weightedRatings, kappa));
                    }

                    continue;
                }

                final double averageRatingActive = averageRatings.get(activeUser.getUser());

                prediction += averageRatingActive;

                if (settings.isNormalizePredictions()) {
                    // Clamp into the tenant's rating range.
                    prediction = Math.max(prediction, minRatingValue);
                    prediction = Math.min(prediction, maxRatingValue);
                }

                final UserAssoc userAssoc = new UserAssoc(prediction, changeDate, item, sourceTypeId, tenantId,
                        activeUser.getUser());

                userAssocDao.insertOrUpdateUserAssoc(userAssoc);
            }
        }
    }
}