List of usage examples for org.apache.mahout.math.Vector.assign
Vector assign(Vector other, DoubleDoubleFunction function);
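This overload replaces each element this[i] with function.apply(this[i], other[i]) and returns the vector itself, so assign(other, Functions.PLUS) is an in-place element-wise sum. A minimal sketch before the full program below; the class name and data here are illustrative, not from the source file:

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;

public class AssignDemo {
    public static void main(String[] args) {
        // Illustrative data: accumulate one vector into another element-wise.
        Vector acc = new DenseVector(new double[] {1.0, 2.0, 3.0});
        Vector delta = new DenseVector(new double[] {0.5, 0.5, 0.5});

        // assign(other, f) sets acc[i] = f.apply(acc[i], other[i]) for every index,
        // so Functions.PLUS makes this an in-place vector sum. The two vectors
        // must have the same cardinality, or a CardinalityException is thrown.
        acc.assign(delta, Functions.PLUS);

        System.out.println(acc); // prints 1.5, 2.5, 3.5 (exact format depends on the Vector impl)
    }
}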
From source file:mlbench.bayes.train.WeightSummer.java
License:Apache License
@SuppressWarnings("deprecation") public static void main(String[] args) throws MPI_D_Exception, IOException, MPIException { parseArgs(args);/*from ww w. j a va 2 s .c o m*/ HashMap<String, String> conf = new HashMap<String, String>(); initConf(conf); MPI_D.Init(args, MPI_D.Mode.Common, conf); if (MPI_D.COMM_BIPARTITE_O != null) { int rank = MPI_D.Comm_rank(MPI_D.COMM_BIPARTITE_O); int size = MPI_D.Comm_size(MPI_D.COMM_BIPARTITE_O); FileSplit[] inputs = DataMPIUtil.HDFSDataLocalLocator.getTaskInputs(MPI_D.COMM_BIPARTITE_O, (JobConf) config, inDir, rank); Vector weightsPerFeature = null; Vector weightsPerLabel = new DenseVector(labNum); for (int i = 0; i < inputs.length; i++) { FileSplit fsplit = inputs[i]; SequenceFileRecordReader<IntWritable, VectorWritable> kvrr = new SequenceFileRecordReader<>(config, fsplit); IntWritable index = kvrr.createKey(); VectorWritable value = kvrr.createValue(); while (kvrr.next(index, value)) { Vector instance = value.get(); if (weightsPerFeature == null) { weightsPerFeature = new RandomAccessSparseVector(instance.size(), instance.getNumNondefaultElements()); } int label = index.get(); weightsPerFeature.assign(instance, Functions.PLUS); weightsPerLabel.set(label, weightsPerLabel.get(label) + instance.zSum()); } } if (weightsPerFeature != null) { MPI_D.Send(new Text(WEIGHTS_PER_FEATURE), new VectorWritable(weightsPerFeature)); MPI_D.Send(new Text(WEIGHTS_PER_LABEL), new VectorWritable(weightsPerLabel)); } } else if (MPI_D.COMM_BIPARTITE_A != null) { int rank = MPI_D.Comm_rank(MPI_D.COMM_BIPARTITE_A); config.set(MAPRED_OUTPUT_DIR, outDirW); config.set("mapred.task.id", DataMPIUtil.getHadoopTaskAttemptID().toString().toString()); ((JobConf) config).setOutputKeyClass(Text.class); ((JobConf) config).setOutputValueClass(VectorWritable.class); TaskAttemptContext taskContext = new TaskAttemptContextImpl(config, DataMPIUtil.getHadoopTaskAttemptID()); SequenceFileOutputFormat<Text, VectorWritable> outfile = new SequenceFileOutputFormat<>(); FileSystem fs = FileSystem.get(config); Path output = new Path(config.get(MAPRED_OUTPUT_DIR)); FileOutputCommitter fcommitter = new FileOutputCommitter(output, taskContext); RecordWriter<Text, VectorWritable> outrw = null; try { fcommitter.setupJob(taskContext); outrw = outfile.getRecordWriter(fs, (JobConf) config, getOutputName(rank), null); } catch (IOException e) { e.printStackTrace(); System.err.println("ERROR: Please set the HDFS configuration properly\n"); System.exit(-1); } Text key = null, newKey = null; VectorWritable point = null, newPoint = null; Vector vector = null; Object[] vals = MPI_D.Recv(); while (vals != null) { newKey = (Text) vals[0]; newPoint = (VectorWritable) vals[1]; if (key == null && point == null) { } else if (!key.equals(newKey)) { outrw.write(key, new VectorWritable(vector)); vector = null; } if (vector == null) { vector = newPoint.get(); } else { vector.assign(newPoint.get(), Functions.PLUS); } key = newKey; point = newPoint; vals = MPI_D.Recv(); } if (newKey != null && newPoint != null) { outrw.write(key, new VectorWritable(vector)); } outrw.close(null); if (fcommitter.needsTaskCommit(taskContext)) { fcommitter.commitTask(taskContext); } MPI_D.COMM_BIPARTITE_A.Barrier(); if (rank == 0) { Path resOut = new Path(outDir); NaiveBayesModel naiveBayesModel = BayesUtils.readModelFromDir(new Path(outDir), config); naiveBayesModel.serialize(resOut, config); } } MPI_D.Finalize(); }