org.grouplens.lenskit.eval.traintest.ExternalEvalJob.java Source code

Java tutorial

Introduction

Here is the source code for org.grouplens.lenskit.eval.traintest.ExternalEvalJob.java

Source

/*
 * LensKit, an open source recommender systems toolkit.
 * Copyright 2010-2014 LensKit Contributors.  See CONTRIBUTORS.md.
 * Work on LensKit has been funded by the National Science Foundation under
 * grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program; if not, write to the Free Software Foundation, Inc., 51
 * Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 */
package org.grouplens.lenskit.eval.traintest;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.longs.Long2DoubleMap;
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
import it.unimi.dsi.fastutil.longs.Long2ObjectMap;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
import org.apache.commons.lang3.StringUtils;
import org.grouplens.lenskit.Recommender;
import org.grouplens.lenskit.RecommenderBuildException;
import org.grouplens.lenskit.collections.CollectionUtils;
import org.grouplens.lenskit.cursors.Cursor;
import org.grouplens.lenskit.data.dao.EventDAO;
import org.grouplens.lenskit.data.dao.UserEventDAO;
import org.grouplens.lenskit.data.event.Event;
import org.grouplens.lenskit.data.event.Rating;
import org.grouplens.lenskit.data.history.History;
import org.grouplens.lenskit.data.history.UserHistory;
import org.grouplens.lenskit.data.pref.Preference;
import org.grouplens.lenskit.eval.data.CSVDataSource;
import org.grouplens.lenskit.eval.data.traintest.GenericTTDataSet;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.topn.ItemSelector;
import org.grouplens.lenskit.scored.ScoredId;
import org.grouplens.lenskit.util.DelimitedTextCursor;
import org.grouplens.lenskit.util.io.LoggingStreamSlurper;
import org.grouplens.lenskit.util.table.writer.CSVWriter;
import org.grouplens.lenskit.util.table.writer.TableWriter;
import org.grouplens.lenskit.vectors.ImmutableSparseVector;
import org.grouplens.lenskit.vectors.SparseVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.List;
import java.util.UUID;

/**
 * Job implementation for exeternal algorithms.
 * @author <a href="http://www.grouplens.org">GroupLens Research</a>
 */
class ExternalEvalJob extends TrainTestJob {
    private static final Logger logger = LoggerFactory.getLogger(ExternalEvalJob.class);
    private final ExternalAlgorithm algorithm;
    private final UUID key;
    private Long2ObjectMap<SparseVector> userPredictions;
    // References to hold on to the user event DAOs for test users
    private UserEventDAO userTrainingEvents, userTestEvents;

    public ExternalEvalJob(TrainTestEvalTask task, @Nonnull ExternalAlgorithm algo, @Nonnull TTDataSet ds) {
        super(task, algo, ds);
        algorithm = algo;
        key = UUID.randomUUID();
    }

    @Override
    protected void buildRecommender() throws RecommenderBuildException {
        Preconditions.checkState(userPredictions == null, "recommender already built");
        File dir = getStagingDir();
        logger.info("using output/staging directory {}", dir);
        if (!dir.exists()) {
            logger.info("creating directory {}", dir);
            dir.mkdirs();
        }

        final File train;
        try {
            train = trainingFile(dataSet);
        } catch (IOException e) {
            throw new RecommenderBuildException("error preparing training file", e);
        }
        final File test;
        try {
            test = testFile(dataSet);
        } catch (IOException e) {
            throw new RecommenderBuildException("error preparing test file", e);
        }
        final File output = getFile("predictions.csv");
        List<String> args = Lists.transform(algorithm.getCommand(), new Function<String, String>() {
            @Nullable
            @Override
            public String apply(@Nullable String input) {
                if (input == null) {
                    throw new NullPointerException("command element");
                }
                String s = input.replace("{OUTPUT}", output.getAbsolutePath());
                s = s.replace("{TRAIN_DATA}", train.getAbsolutePath());
                s = s.replace("{TEST_DATA}", test.getAbsolutePath());
                return s;
            }
        });

        logger.info("running {}", StringUtils.join(args, " "));

        Process proc;
        try {
            proc = new ProcessBuilder().command(args).directory(algorithm.getWorkDir()).start();
        } catch (IOException e) {
            throw new RecommenderBuildException("error creating process", e);
        }
        Thread listen = new LoggingStreamSlurper("external-algo", proc.getErrorStream(), logger, "external: ");
        listen.run();

        int result = -1;
        boolean done = false;
        while (!done) {
            try {
                result = proc.waitFor();
                done = true;
            } catch (InterruptedException e) {
                logger.info("thread interrupted, killing subprocess");
                proc.destroy();
                throw new RecommenderBuildException("recommender build interrupted", e);
            }
        }

        if (result != 0) {
            logger.error("external command exited with status {}", result);
            throw new RecommenderBuildException("recommender exited with code " + result);
        }

        Long2ObjectMap<SparseVector> vectors;
        try {
            vectors = readPredictions(output);
        } catch (FileNotFoundException e) {
            logger.error("cannot find expected output file {}", output);
            throw new RecommenderBuildException("recommender produced no output", e);
        }

        userPredictions = vectors;

        userTrainingEvents = dataSet.getTrainingData().getUserEventDAO();
        userTestEvents = dataSet.getTestData().getUserEventDAO();
    }

    @Override
    protected TestUser getUserResults(long uid) {
        Preconditions.checkState(userPredictions != null, "recommender not built");
        return new TestUserImpl(uid);
    }

    @Override
    protected void cleanup() {
        userPredictions = null;
        userTrainingEvents = null;
        userTestEvents = null;
    }

    private File getStagingDir() {
        String dirName = String.format("%s-%s", algorithm.getName(), key);
        return new File(algorithm.getWorkDir(), dirName);
    }

    /**
     * Create a fully-qualified algorithm file name.
     * @param fn The file name.
     * @return A file in the working directory.
     */
    private File getFile(String fn) {
        return new File(getStagingDir(), fn);
    }

    private File trainingFile(TTDataSet data) throws IOException {
        try {
            GenericTTDataSet gds = (GenericTTDataSet) data;
            CSVDataSource csv = (CSVDataSource) gds.getTrainingData();
            if (",".equals(csv.getDelimiter())) {
                File file = csv.getFile();
                logger.debug("using training file {}", file);
                return file;
            }
        } catch (ClassCastException e) {
            /* No-op - this is fine, we will make a file. */
        }
        File file = makeCSV(data.getTrainingDAO(), getFile("train.csv"), true);
        logger.debug("wrote training file {}", file);
        return file;
    }

    private File testFile(TTDataSet data) throws IOException {
        File file = makeCSV(data.getTestDAO(), getFile("test.csv"), false);
        logger.debug("wrote test file {}", file);
        return file;
    }

    private File makeCSV(EventDAO dao, File file, boolean writeRatings) throws IOException {
        // TODO Make this not re-copy data unnecessarily
        Object[] row = new Object[writeRatings ? 3 : 2];
        TableWriter table = CSVWriter.open(file, null);
        try {
            Cursor<Rating> ratings = dao.streamEvents(Rating.class);
            try {
                for (Rating r : ratings.fast()) {
                    Preference p = r.getPreference();
                    if (p != null) {
                        row[0] = r.getUserId();
                        row[1] = r.getItemId();
                        if (writeRatings) {
                            row[2] = p.getValue();
                        }
                        table.writeRow(row);
                    }
                }
            } finally {
                ratings.close();
            }
        } finally {
            table.close();
        }
        return file;
    }

    private Long2ObjectMap<SparseVector> readPredictions(File predFile)
            throws FileNotFoundException, RecommenderBuildException {
        Long2ObjectMap<Long2DoubleMap> data = new Long2ObjectOpenHashMap<Long2DoubleMap>();
        Cursor<String[]> cursor = new DelimitedTextCursor(predFile, algorithm.getOutputDelimiter());
        try {
            int n = 0;
            for (String[] row : cursor) {
                n++;
                if (row.length == 0 || row.length == 1 && row[0].equals("")) {
                    continue;
                }
                if (row.length < 3) {
                    logger.error("predictions line {}: invalid row {}", n, StringUtils.join(row, ","));
                    throw new RecommenderBuildException("invalid prediction row");
                }
                long uid = Long.parseLong(row[0]);
                long iid = Long.parseLong(row[1]);
                double pred = Double.parseDouble(row[2]);
                Long2DoubleMap user = data.get(uid);
                if (user == null) {
                    user = new Long2DoubleOpenHashMap();
                    data.put(uid, user);
                }
                user.put(iid, pred);
            }
        } finally {
            cursor.close();
        }
        Long2ObjectMap<SparseVector> vectors = new Long2ObjectOpenHashMap<SparseVector>(data.size());
        for (Long2ObjectMap.Entry<Long2DoubleMap> entry : CollectionUtils.fast(data.long2ObjectEntrySet())) {
            vectors.put(entry.getLongKey(), ImmutableSparseVector.create(entry.getValue()));
        }
        return vectors;
    }

    /**
     * External algorithmInfo implementation of TestUser.
     * @author <a href="http://www.grouplens.org">GroupLens Research</a>
     */
    class TestUserImpl extends AbstractTestUser {
        private final long userId;

        public TestUserImpl(long uid) {
            userId = uid;
        }

        @Override
        public UserHistory<Event> getTrainHistory() {
            UserHistory<Event> events = userTrainingEvents.getEventsForUser(userId);
            if (events == null) {
                return History.forUser(userId); //Creates an empty history for this particular user.
            } else {
                return events;
            }
        }

        @Override
        public UserHistory<Event> getTestHistory() {
            return userTestEvents.getEventsForUser(userId);
        }

        @Override
        public SparseVector getPredictions() {
            return userPredictions.get(userId);
        }

        @Override
        public List<ScoredId> getRecommendations(int n, ItemSelector candSel, ItemSelector exclSel) {
            return null;
        }

        @Override
        public Recommender getRecommender() {
            return null;
        }
    }
}