org.plista.kornakapi.core.training.MultithreadedItembasedInMemoryTrainer.java Source code

Java tutorial

Introduction

Here is the source code for org.plista.kornakapi.core.training.MultithreadedItembasedInMemoryTrainer.java

Source

/**
 * Copyright 2012 plista GmbH  (http://www.plista.com/)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and limitations under the License.
 */

package org.plista.kornakapi.core.training;

import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
import org.plista.kornakapi.core.config.ItembasedRecommenderConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * A multi-threaded trainer for item kNN recommenders.
 *
 * <p>Training splits the item IDs of the in-memory data model into batches, lets a pool of
 * {@link SimilarItemsWorker}s compute the most similar items for every batch and funnels the
 * resulting text through a single {@link OutputWriter} into the target file. Every emitted line
 * has the form {@code itemID,similarItemID,similarity}.
 */
public class MultithreadedItembasedInMemoryTrainer extends AbstractTrainer {

    private final ItembasedRecommenderConfig conf;

    private static final Logger log = LoggerFactory.getLogger(MultithreadedItembasedInMemoryTrainer.class);

    /** number of item IDs that are handed to a worker at once */
    private static final int BATCH_SIZE = 100;

    public MultithreadedItembasedInMemoryTrainer(ItembasedRecommenderConfig conf) {
        super(conf);
        this.conf = conf;
    }

    /**
     * Computes the similar items for all items of the data model and writes them to the target file.
     *
     * @param targetFile file that receives one {@code itemID,similarItemID,similarity} line per pair
     * @param inmemoryData the data model to compute similarities from
     * @param numProcessors number of worker threads to use (one additional thread writes the output)
     * @throws IOException if the similarity cannot be instantiated or the target file cannot be opened
     */
    @Override
    protected void doTrain(File targetFile, DataModel inmemoryData, int numProcessors) throws IOException {
        BufferedWriter writer = null;

        // one thread per worker plus one for the OutputWriter
        ExecutorService executorService = Executors.newFixedThreadPool(numProcessors + 1);

        try {

            ItemSimilarity similarity = (ItemSimilarity) Class.forName(conf.getSimilarityClass())
                    .getConstructor(DataModel.class).newInstance(inmemoryData);

            ItemBasedRecommender trainer = new GenericItemBasedRecommender(inmemoryData, similarity);

            writer = new BufferedWriter(new FileWriter(targetFile));

            int numItems = inmemoryData.getNumItems();

            List<long[]> itemIDBatches = queueItemIDsInBatches(inmemoryData.getItemIDs(), numItems, BATCH_SIZE);

            log.info("Queued {} items in {} batches", numItems, itemIDBatches.size());

            BlockingQueue<long[]> itemsIDsToProcess = new LinkedBlockingQueue<long[]>(itemIDBatches);
            BlockingQueue<String> output = new LinkedBlockingQueue<String>();

            AtomicInteger numActiveWorkers = new AtomicInteger(numProcessors);
            for (int n = 0; n < numProcessors; n++) {
                executorService.execute(new SimilarItemsWorker(n, itemsIDsToProcess, output, trainer,
                        conf.getSimilarItemsPerItem(), numActiveWorkers));
            }
            executorService.execute(new OutputWriter(output, writer, numActiveWorkers));

        } catch (Exception e) {
            throw new IOException(e);
        } finally {
            executorService.shutdown();
            try {
                if (!executorService.awaitTermination(6, TimeUnit.HOURS)) {
                    // do not leave tasks running against a writer that is closed below
                    log.warn("Training did not finish within 6 hours, cancelling remaining tasks");
                    executorService.shutdownNow();
                }
            } catch (InterruptedException e) {
                executorService.shutdownNow();
                // preserve the interrupt for the caller instead of swallowing it
                Thread.currentThread().interrupt();
            }
            Closeables.closeQuietly(writer);
        }
    }

    /**
     * Partitions the given item IDs into arrays of at most {@code batchSize} elements.
     *
     * @param itemIDs iterator over all item IDs
     * @param numItems expected number of item IDs, only used to presize the result list
     * @param batchSize maximum number of item IDs per batch
     * @return the batches in iteration order; the last batch holds only the remaining IDs and the
     *         list is empty if there are no item IDs
     */
    private List<long[]> queueItemIDsInBatches(LongPrimitiveIterator itemIDs, int numItems, int batchSize) {
        List<long[]> itemIDBatches = Lists.newArrayListWithCapacity(numItems / batchSize + 1);

        long[] batch = new long[batchSize];
        int pos = 0;
        while (itemIDs.hasNext()) {
            batch[pos] = itemIDs.nextLong();
            pos++;
            if (pos == batchSize) {
                // buffer is full: hand out a copy so the buffer can be reused
                itemIDBatches.add(batch.clone());
                pos = 0;
            }
        }
        // 'pos' is the number of IDs sitting in the buffer that have not been queued yet
        if (pos > 0) {
            long[] lastBatch = new long[pos];
            System.arraycopy(batch, 0, lastBatch, 0, pos);
            itemIDBatches.add(lastBatch);
        }
        return itemIDBatches;
    }

    /** Writes the text produced by the workers to the output file until all workers are done. */
    static class OutputWriter implements Runnable {

        private final BlockingQueue<String> output;
        private final BufferedWriter writer;
        private final AtomicInteger numActiveWorkers;

        OutputWriter(BlockingQueue<String> output, BufferedWriter writer, AtomicInteger numActiveWorkers) {
            this.output = output;
            this.writer = writer;
            this.numActiveWorkers = numActiveWorkers;
        }

        @Override
        public void run() {
            try {
                while (numActiveWorkers.get() != 0) {
                    String lines = output.poll(10, TimeUnit.MILLISECONDS);
                    if (null != lines) {
                        writer.write(lines);
                    }
                }
                // workers enqueue their result before they decrement the counter, so anything
                // still queued at this point must be written as well
                String remaining;
                while ((remaining = output.poll()) != null) {
                    writer.write(remaining);
                }
                writer.flush();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    /** Takes batches of item IDs from a shared queue and computes the most similar items for them. */
    static class SimilarItemsWorker implements Runnable {

        private final int number;
        private final BlockingQueue<long[]> itemIDBatches;
        private final BlockingQueue<String> output;
        private final ItemBasedRecommender trainer;
        private final int howMany;
        private final AtomicInteger numActiveWorkers;

        SimilarItemsWorker(int number, BlockingQueue<long[]> itemIDBatches, BlockingQueue<String> output,
                ItemBasedRecommender trainer, int howMany, AtomicInteger numActiveWorkers) {
            this.number = number;
            this.itemIDBatches = itemIDBatches;
            this.output = output;
            this.trainer = trainer;
            this.howMany = howMany;
            this.numActiveWorkers = numActiveWorkers;
        }

        @Override
        public void run() {
            int numBatchesProcessed = 0;
            try {
                long[] itemIDBatch;
                // non-blocking poll(): an isEmpty() check followed by take() could block forever
                // when another worker grabs the last batch in between
                while ((itemIDBatch = itemIDBatches.poll()) != null) {
                    StringBuilder lines = new StringBuilder();

                    for (long itemID : itemIDBatch) {
                        Iterable<RecommendedItem> similarItems = trainer.mostSimilarItems(itemID, howMany);

                        for (RecommendedItem similarItem : similarItems) {
                            lines.append(itemID).append(',').append(similarItem.getItemID()).append(',')
                                    .append(similarItem.getValue()).append('\n');
                        }
                    }

                    output.offer(lines.toString());

                    if (++numBatchesProcessed % 5 == 0) {
                        log.info("worker {} processed {} batches", number, numBatchesProcessed);
                    }
                }
                log.info("worker {} processed {} batches. done.", number, numBatchesProcessed);
            } catch (Exception e) {
                throw new RuntimeException(e);
            } finally {
                // always sign off, otherwise the OutputWriter would wait for this worker forever
                numActiveWorkers.decrementAndGet();
            }
        }
    }
}