org.easyrec.plugin.pearson.impl.PearsonServiceImpl.java Source code

Java tutorial

Introduction

Here is the source code for org.easyrec.plugin.pearson.impl.PearsonServiceImpl.java

Source

/*
 * Copyright 2011 Research Studios Austria Forschungsgesellschaft mBH
 *
 * This file is part of easyrec.
 *
 * easyrec is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * easyrec is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with easyrec.  If not, see <http://www.gnu.org/licenses/>.
 */

package org.easyrec.plugin.pearson.impl;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.easyrec.model.core.ItemVO;
import org.easyrec.model.core.RatingVO;
import org.easyrec.model.core.TenantVO;
import org.easyrec.plugin.pearson.PearsonService;
import org.easyrec.plugin.pearson.model.Settings;
import org.easyrec.plugin.pearson.model.User;
import org.easyrec.plugin.pearson.model.UserAssoc;
import org.easyrec.plugin.pearson.model.Weight;
import org.easyrec.plugin.pearson.store.dao.LatestActionDAO;
import org.easyrec.plugin.pearson.store.dao.LatestActionDAO.RatedTogether;
import org.easyrec.plugin.pearson.store.dao.UserAssocDAO;
import org.easyrec.plugin.pearson.store.dao.UserDAO;
import org.easyrec.plugin.pearson.store.dao.WeightDAO;
import org.easyrec.service.core.TenantService;
import org.easyrec.service.domain.TypeMappingService;

import java.util.*;

public class PearsonServiceImpl implements PearsonService {
    private final LatestActionDAO latestActionDao;
    private final Settings settings;
    private final TenantService tenantService;
    private final TypeMappingService typeMappingService;
    private final UserAssocDAO userAssocDao;
    private final UserDAO userDao;
    private final WeightDAO weightDao;
    protected final Log logger = LogFactory.getLog(getClass());

    public PearsonServiceImpl(final Settings settings, final WeightDAO weightDao, final UserAssocDAO userAssocDao,
            final UserDAO userDao, final LatestActionDAO latestActionDao, final TenantService tenantService,
            final TypeMappingService typeMappingService) {
        this.settings = settings;
        this.weightDao = weightDao;
        this.userAssocDao = userAssocDao;
        this.userDao = userDao;
        this.latestActionDao = latestActionDao;
        this.tenantService = tenantService;
        this.typeMappingService = typeMappingService;
    }

    public void perform(final Integer tenantId) {
        List<TenantVO> tenants;

        if (tenantId != null) {
            final TenantVO tenant = tenantService.getTenantById(tenantId);

            tenants = new Vector<TenantVO>(1);
            tenants.add(tenant);
        } else
            tenants = tenantService.getAllTenants();

        for (final TenantVO tenant : tenants)
            performForTenant(tenant);
    }

    private void calculateWeights(final Integer tenantId, final Integer actionTypeId, final Integer itemTypeId,
            final List<User> users, final Map<Integer, Double> averageRatings) {
        final int userCount = users.size();

        final int perc25 = (int) (userCount * 0.25);
        final int perc50 = (int) (userCount * 0.5);
        final int perc75 = (int) (userCount * 0.75);

        for (int i = 0; i < userCount; i++) {
            final User activeUser = users.get(i);
            final double averageActive = averageRatings.get(activeUser.getUser());

            if (logger.isInfoEnabled()) {
                if (i == perc25)
                    logger.info("Weight calculation at 25%");
                if (i == perc50)
                    logger.info("Weight calculation at 50%");
                if (i == perc75)
                    logger.info("Weight calculation at 75%");

                if (i % 10 == 0)
                    logger.info(String.format("Weight calculation at user %d of %d", i, userCount));
            }

            for (int j = i + 1; j < userCount; j++) {
                final User otherUser = users.get(j);
                final List<RatedTogether<Integer, Integer>> ratedTogether = latestActionDao
                        .getItemsRatedTogetherByUsers(tenantId, itemTypeId, activeUser.getUser(),
                                otherUser.getUser(), actionTypeId);

                // users don't have common rated items
                if (ratedTogether.size() == 0)
                    continue;

                final double averageOther = averageRatings.get(otherUser.getUser());
                double frequency = 1.0;

                if (settings.isUseInverseUserFrequency()) {
                    frequency = userCount / ratedTogether.size();
                    frequency = Math.log10(frequency);

                    if (frequency == 0.0)
                        continue;
                }

                double frequencySum = 0.0;
                double expectedBoth = 0.0;
                double expectedActive = 0.0;
                double expectedOther = 0.0;
                double expectedActiveSquare = 0.0;
                double expectedOtherSquare = 0.0;

                for (final RatedTogether<Integer, Integer> rating : ratedTogether) {
                    final double ratingActive = rating.getRating1().getRatingValue();
                    final double ratingOther = rating.getRating2().getRatingValue();

                    frequencySum += frequency;
                    expectedBoth += frequency * ratingActive * ratingOther;
                    expectedActive += frequency * ratingActive;
                    expectedOther += frequency * ratingOther;
                    expectedActiveSquare += frequency * Math.pow(ratingActive, 2.0);
                    expectedOtherSquare += frequency * Math.pow(ratingOther, 2.0);
                }

                // TODO replace EX^2 - (EX)^2 with E((X-EX)^2) for better stability
                final double varianceActive = frequencySum * expectedActiveSquare - Math.pow(expectedActive, 2.0);
                final double varianceOther = frequencySum * expectedOtherSquare - Math.pow(expectedOther, 2.0);

                double numerator1 = frequencySum * expectedBoth;
                double numerator2 = expectedActive * expectedOther;
                final double denominator = Math.sqrt(varianceActive * varianceOther);

                numerator1 /= denominator;
                numerator2 /= denominator;

                final double weight = numerator1 - numerator2;

                if (Double.isNaN(weight) || Double.isInfinite(weight)) {
                    if (logger.isWarnEnabled())
                        logger.warn(String.format(
                                "Weight is %s for users %d and %d (vAct=%.2f, vOth=%.2f, Eact2=%.2f, Eoth2=%.2f, "
                                        + "Ebot=%.2f, Eact=%.2f, Eoth=%.2f, fre=%.2f fsum=%.2f, num1=%.2f, "
                                        + "numer2=%.2f, den=%.2f)",
                                Double.isNaN(weight) ? "NaN" : "Inf", i, j, varianceActive, varianceOther,
                                expectedActiveSquare, expectedOtherSquare, expectedBoth, expectedActive,
                                expectedOther, frequency, frequencySum, numerator1, numerator2, denominator));

                    continue;
                }

                final Weight weightObj = new Weight(activeUser, otherUser, weight);

                weightDao.insertOrUpdateWeightSymmetric(weightObj);
            }
        }
    }

    private Map<Integer, Double> getAverageUserRatingMap(final Integer tenantId, final Integer itemTypeId) {
        final List<RatingVO<Integer, Integer>> averageRatings = latestActionDao.getAverageRatingsForUser(tenantId,
                itemTypeId);
        final Map<Integer, Double> result = new HashMap<Integer, Double>(averageRatings.size());

        for (final RatingVO<Integer, Integer> averageRating : averageRatings)
            result.put(averageRating.getUser(), averageRating.getRatingValue());

        return result;
    }

    private void performForTenant(final TenantVO tenant) {
        if (tenant == null)
            throw new IllegalArgumentException("tenant is null");

        final Integer tenantId = tenant.getId();
        final Integer actionTypeId = typeMappingService.getIdOfActionType(tenantId, settings.getActionType());
        final Integer itemTypeId = typeMappingService.getIdOfItemType(tenantId, settings.getItemType());
        final Integer sourceTypeId = typeMappingService.getIdOfSourceType(tenantId, settings.getSourceType());
        final Integer minRatingValue = tenant.getRatingRangeMin();
        final Integer maxRatingValue = tenant.getRatingRangeMax();
        final Date changeDate = new Date();

        final List<User> users = userDao.getUsersForTenant(tenantId);
        final Map<Integer, Double> averageRatings = getAverageUserRatingMap(tenantId, itemTypeId);

        logger.info("Starting weight calculation.");
        Date start = new Date();

        // calculateWeights(tenantId, actionTypeId, itemTypeId, users, averageRatings);

        Date end = new Date();
        double time = (end.getTime() - start.getTime()) / 1000L;
        logger.info(String.format("Calculating weights for %s took %.2f seconds", tenant.getStringId(), time));

        logger.info("Starting predictions.");
        start = new Date();

        predict(tenantId, actionTypeId, itemTypeId, sourceTypeId, changeDate, users, averageRatings, minRatingValue,
                maxRatingValue);

        end = new Date();
        time = (end.getTime() - start.getTime()) / 1000L;
        logger.info(String.format("Calculating USER-ITEM predictions for %s took %.2f seconds",
                tenant.getStringId(), time));
    }

    private void predict(final Integer tenantId, final Integer actionTypeId, final Integer itemTypeId,
            final Integer sourceTypeId, final Date changeDate, final List<User> users,
            final Map<Integer, Double> averageRatings, final Integer minRatingValue, final Integer maxRatingValue) {
        // final List<ItemVO<Integer, Integer>> items = latestActionDao.getAvailableItemsForTenant(tenantId,
        // itemTypeId);

        final double caseAmplification = settings.getCaseAmplification();
        final boolean useCaseAmplification = settings.getCaseAmplification() != null;

        int cur = 0;
        final int perc25 = (int) (users.size() * 0.25);
        final int perc50 = (int) (users.size() * 0.5);
        final int perc75 = (int) (users.size() * 0.75);

        for (final User activeUser : users) {
            if (logger.isInfoEnabled()) {
                if (cur == perc25)
                    logger.info("Predictions at 25%");
                if (cur == perc50)
                    logger.info("Predictions at 50%");
                if (cur == perc75)
                    logger.info("Predictions at 75%");

                logger.info(String.format("Predictions at user %d of %d", cur, users.size()));
            }

            cur++;

            // final List<Weight> weights = weightDao.getWeightsForUser1(tenantId, activeUser.getUser());

            // if (weights.size() == 0) {
            // if (logger.isInfoEnabled())
            // logger.info(String.format(
            // "Couldn't calculate prediction for user %d because no weights were present",
            // activeUser.getUser()));
            //
            // continue;
            // }

            final List<ItemVO<Integer, Integer>> items;

            if (settings.getTestDataSourceType() == null)
                items = latestActionDao.getItemsNotRatedByUser(tenantId, activeUser.getUser(), itemTypeId);
            else
                // we are only predicting for the test-set
                items = userAssocDao.getItemsAssociatedToUser(tenantId, activeUser.getUser(), itemTypeId,
                        settings.getTestDataSourceType());

            for (final ItemVO<Integer, Integer> item : items) {
                if (latestActionDao.didUserRateItem(activeUser.getUser(), item, actionTypeId))
                    continue;

                double kappa = 0.0;
                double weightedRatings = 0.0;

                final List<Weight> weights = weightDao.getWeightsForUser1AndItem(tenantId, activeUser.getUser(),
                        item.getItem(), item.getType());

                if (weights.size() == 0) {
                    if (logger.isInfoEnabled())
                        logger.info(String.format(
                                "Couldn't calculate prediction for user %d and item %d because no weights were present",
                                activeUser.getUser(), item.getItem()));

                    continue;
                }

                // TODO get users that rated item
                // TODO get weights for users that rated item
                for (final Weight weight : weights) {
                    double currentWeight = weight.getWeight();
                    final User otherUser = weight.getUser2();

                    final List<RatingVO<Integer, Integer>> ratingsOther = latestActionDao.getLatestRatingsForTenant(
                            tenantId, itemTypeId, item.getItem(), otherUser.getUser(), null);

                    if (ratingsOther.size() > 1 && logger.isWarnEnabled())
                        logger.warn("    There shouldn't be more than 1 rating");

                    double ratingOther = 0.0;

                    if (ratingsOther.size() == 1)
                        ratingOther = ratingsOther.get(0).getRatingValue();
                    else
                        // the other user didn't rate the current item
                        continue;

                    if (useCaseAmplification)
                        if (currentWeight >= 0)
                            currentWeight = Math.pow(currentWeight, caseAmplification);
                        else
                            currentWeight = -Math.pow(-currentWeight, caseAmplification);

                    final double averageRatingOther = averageRatings.get(otherUser.getUser());
                    kappa += Math.abs(currentWeight);
                    weightedRatings += currentWeight * (ratingOther - averageRatingOther);
                }

                if (kappa == 0.0 || Double.isNaN(kappa) || Double.isInfinite(kappa)) {
                    if (logger.isInfoEnabled())
                        logger.info(String.format("    Prediction for user %d item %d failed (kappa=%f)",
                                activeUser.getUser(), item.getItem(), kappa));

                    continue;
                }

                double prediction = weightedRatings / kappa;

                if (Double.isNaN(weightedRatings) || Double.isInfinite(weightedRatings)) {
                    if (logger.isInfoEnabled())
                        logger.info(String.format(
                                "    Prediction for user %d item %d failed (prediction=%f, weightedRatings=%f, kappa=%f)",
                                activeUser.getUser(), item.getItem(), prediction, weightedRatings, kappa));

                    continue;
                }

                final double averageRatingActive = averageRatings.get(activeUser.getUser());
                prediction += averageRatingActive;

                if (settings.isNormalizePredictions()) {
                    prediction = Math.max(prediction, minRatingValue);
                    prediction = Math.min(prediction, maxRatingValue);
                }

                final UserAssoc userAssoc = new UserAssoc(prediction, changeDate, item, sourceTypeId, tenantId,
                        activeUser.getUser());
                userAssocDao.insertOrUpdateUserAssoc(userAssoc);
            }
        }
    }

}