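Typically, the fork/join framework is suitable in a situation where a task can be divided into multiple subtasks that can be executed independently and the results of all subtasks can be combined.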
The following four classes are the main classes in the fork/join framework:
ForkJoinPool, ForkJoinTask, RecursiveAction, and RecursiveTask.
The ForkJoinPool class represents a thread pool.
The ForkJoinTask class represents a task.
The abstract ForkJoinTask class has two commonly used subclasses, RecursiveAction and RecursiveTask, which are themselves abstract.
A third abstract subclass of the ForkJoinTask class is called CountedCompleter.
The framework supports two types of tasks: a task that does not yield a result and a task that yields a result.
The RecursiveAction class represents a task that does not yield a result.
The RecursiveTask class represents a task that yields a result.
A CountedCompleter task may or may not yield a result.
A class that represents a fork/join task should inherit from RecursiveAction or RecursiveTask and implement the abstract compute() method.
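As a minimal sketch of the two kinds of tasks, the following code declares one task of each type. The names TaskKindsDemo, FillAction, and LengthTask are hypothetical and chosen only for illustration. Note that compute() returns void in a RecursiveAction and returns the type argument in a RecursiveTask.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;

class TaskKindsDemo {
    // Hypothetical task without a result: compute() returns void.
    static class FillAction extends RecursiveAction {
        private final int[] target;
        private final int value;

        FillAction(int[] target, int value) {
            this.target = target;
            this.value = value;
        }

        @Override
        protected void compute() {
            // The effect of the task is a side effect on the array.
            for (int i = 0; i < target.length; i++) {
                target[i] = value;
            }
        }
    }

    // Hypothetical task with a result: compute() returns the type argument.
    static class LengthTask extends RecursiveTask<Integer> {
        private final String text;

        LengthTask(String text) {
            this.text = text;
        }

        @Override
        protected Integer compute() {
            return text.length();
        }
    }

    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool();
        int[] cells = new int[3];
        pool.invoke(new FillAction(cells, 7));            // nothing to collect
        int length = pool.invoke(new LengthTask("fork")); // result is returned
        System.out.println(cells[0] + " " + length);      // prints "7 4"
        pool.shutdown();
    }
}
```

Neither task splits its work here; the sketch only contrasts the two compute() signatures.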
Typically, the logic inside the compute() method follows a pattern similar to the following:
if (Task is small) {
  Solve the task.
} else {
  Divide the task into subtasks.
  Launch the subtasks asynchronously (fork stage).
  Wait for the subtasks to finish (join stage).
  Combine the results of all subtasks.
}
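The pattern above can be sketched as a task that sums a range of an array. This is one possible shape under stated assumptions, not the only way to write it: the names RangeSumDemo and RangeSumTask, and the THRESHOLD value of 1,000, are hypothetical choices made for this illustration.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class RangeSumDemo {
    // Hypothetical task: sums numbers[from, to).
    static class RangeSumTask extends RecursiveTask<Long> {
        private static final int THRESHOLD = 1_000; // assumed "small task" limit
        private final int[] numbers;
        private final int from;
        private final int to;

        RangeSumTask(int[] numbers, int from, int to) {
            this.numbers = numbers;
            this.from = from;
            this.to = to;
        }

        @Override
        protected Long compute() {
            if (to - from <= THRESHOLD) {
                // Task is small: solve it directly.
                long sum = 0;
                for (int i = from; i < to; i++) {
                    sum += numbers[i];
                }
                return sum;
            }
            // Divide the task into two subtasks.
            int mid = from + (to - from) / 2;
            RangeSumTask left = new RangeSumTask(numbers, from, mid);
            RangeSumTask right = new RangeSumTask(numbers, mid, to);
            left.fork();                     // fork stage: run left asynchronously
            long rightSum = right.compute(); // current thread handles the right half
            long leftSum = left.join();      // join stage: wait for the left half
            return leftSum + rightSum;       // combine the results
        }
    }

    static long sum(int[] numbers) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(new RangeSumTask(numbers, 0, numbers.length));
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        int[] numbers = new int[10_000];
        for (int i = 0; i < numbers.length; i++) {
            numbers[i] = i + 1;
        }
        System.out.println("Sum is " + sum(numbers)); // prints "Sum is 50005000"
    }
}
```

Forking only one half and computing the other in the current thread keeps the current worker busy instead of leaving it idle while it waits.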
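The following two methods of the ForkJoinTask class provide two important features during a task execution:
The fork() method launches a task for asynchronous execution.
The join() method waits for a task to finish and returns its result.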
Step 1: Declaring a Class to Represent a Task
class MyTask extends RecursiveTask<Long> {
  // Code for your task goes here
}
Step 2: Implementing the compute() Method
class MyTask extends RecursiveTask<Long> {
  public Long compute() {
    // Logic for the task goes here
  }
}
Step 3: Creating a Fork/Join Thread Pool
ForkJoinPool pool = new ForkJoinPool();
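Besides the no-argument constructor, a pool can be created with an explicit parallelism level, or the shared common pool can be used. The following is a small sketch of the three options. The class name PoolCreationDemo and the parallelism value 4 are arbitrary choices for this illustration, and the printed numbers depend on the machine.

```java
import java.util.concurrent.ForkJoinPool;

class PoolCreationDemo {
    public static void main(String[] args) {
        // No-argument constructor: parallelism follows the available processors.
        ForkJoinPool defaultPool = new ForkJoinPool();
        // Explicit parallelism level (4 is an arbitrary value for this sketch).
        ForkJoinPool fixedPool = new ForkJoinPool(4);
        // Shared, JVM-wide pool; callers are not expected to shut it down.
        ForkJoinPool commonPool = ForkJoinPool.commonPool();

        System.out.println("default: " + defaultPool.getParallelism());
        System.out.println("fixed:   " + fixedPool.getParallelism());
        System.out.println("common:  " + commonPool.getParallelism());

        // Shut down only the pools created here.
        defaultPool.shutdown();
        fixedPool.shutdown();
    }
}
```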
Step 4: Creating the Fork/Join Task
MyTask task = new MyTask();
Step 5: Submitting the Task to the Fork/Join Pool for Execution
long result = pool.invoke(task);
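The invoke() method waits for the task to finish before it returns the result. When the caller wants to continue with other work first, the submit() method can be used instead; it returns the task, whose join() method collects the result later. The following sketch contrasts the two calls. FixedValueTask is a hypothetical stand-in for MyTask that simply returns the value it was given.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

class SubmitDemo {
    // Hypothetical stand-in for MyTask: returns a fixed value.
    static class FixedValueTask extends RecursiveTask<Long> {
        private final long value;

        FixedValueTask(long value) {
            this.value = value;
        }

        @Override
        protected Long compute() {
            return value;
        }
    }

    // invoke() blocks until the task has finished and returns its result.
    static long runWithInvoke(ForkJoinPool pool, long value) {
        return pool.invoke(new FixedValueTask(value));
    }

    // submit() schedules the task and returns immediately.
    static long runWithSubmit(ForkJoinPool pool, long value) {
        ForkJoinTask<Long> pending = pool.submit(new FixedValueTask(value));
        // Other work could be done here while the task runs.
        return pending.join(); // wait for the task and fetch its result
    }

    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool();
        System.out.println(runWithInvoke(pool, 10L)); // prints 10
        System.out.println(runWithSubmit(pool, 20L)); // prints 20
        pool.shutdown();
    }
}
```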
The following code uses the fork/join framework to compute the sum of several numbers. Each single-number task returns the fixed value 10, so the program prints "Sum is 30".
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class MyTask extends RecursiveTask<Long> {
  private int count;

  public MyTask(int count) {
    this.count = count;
  }

  @Override
  protected Long compute() {
    long result = 0;
    if (this.count <= 0) {
      return 0L; // We do not have anything to do
    }
    if (this.count == 1) {
      return 10L; // A single task contributes the fixed value 10
    }
    // Multiple numbers. Divide them into many single tasks. Keep the references
    // of all tasks to call their join() method later
    List<RecursiveTask<Long>> forks = new ArrayList<>();
    for (int i = 0; i < this.count; i++) {
      MyTask subTask = new MyTask(1);
      subTask.fork(); // Launch the subtask
      // Keep the subTask references to combine the results later
      forks.add(subTask);
    }
    // Now wait for all subtasks to finish and combine the result
    for (RecursiveTask<Long> subTask : forks) {
      result = result + subTask.join();
    }
    return result;
  }
}

public class Main {
  public static void main(String[] args) {
    ForkJoinPool pool = new ForkJoinPool();
    MyTask task = new MyTask(3);
    long sum = pool.invoke(task);
    System.out.println("Sum is " + sum);
  }
}
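As a possible variation on the program above, the fork loop can be replaced with the static ForkJoinTask.invokeAll() method, which runs a collection of tasks and returns once all of them have completed. The following is a sketch under that assumption, not a rewrite of the original example. The names InvokeAllDemo, TenTask, and SumOfTensTask are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

class InvokeAllDemo {
    // Hypothetical leaf task: always yields the fixed value 10.
    static class TenTask extends RecursiveTask<Long> {
        @Override
        protected Long compute() {
            return 10L;
        }
    }

    // Hypothetical parent task: runs `count` leaf tasks and adds their results.
    static class SumOfTensTask extends RecursiveTask<Long> {
        private final int count;

        SumOfTensTask(int count) {
            this.count = count;
        }

        @Override
        protected Long compute() {
            List<TenTask> subTasks = new ArrayList<>();
            for (int i = 0; i < count; i++) {
                subTasks.add(new TenTask());
            }
            // Run all subtasks; this call returns after each one has completed.
            ForkJoinTask.invokeAll(subTasks);
            long result = 0;
            for (TenTask subTask : subTasks) {
                result += subTask.join(); // already finished, so this does not block
            }
            return result;
        }
    }

    static long sumOfTens(int count) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(new SumOfTensTask(count));
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("Sum is " + sumOfTens(3)); // prints "Sum is 30"
    }
}
```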