본문 바로가기
Language/Java

[Java] 쓰레드 8 - fork & join 프레임 워크

by 계범 2022. 3. 14.

fork & join 프레임워크

이 프레임워크는 멀티쓰레드 프로그래밍을 구현하기 위해 하나의 작업을 작은 단위로 나눠서 여러 쓰레드가 동시에 처리하는 것을 쉽게 만들어준다.

 

수행할 작업에 따라 아래의 두 클래스 중에서 하나를 상속받아 구현한다.

RecursiveAction : 반환값이 없는 작업을 구현할 때 사용
RecursiveTask : 반환값이 있는 작업을 구현할 때 사용

 

두 클래스 모두 compute()라는 추상 메서드를 가지고 있고, 이 추상 메서드를 구현하면 된다.

 

1) compute()에 작업을 수행하기 위한 코드 넣기

2) 쓰레드풀과 수행할 작업 생성

3) invoke()로 작업 시작.

 

쓰레드풀은 지정된 수의 쓰레드를 생성해서 미리 만들어 놓고 반복해서 재사용할 수 있게 한다.

쓰레드를 반복해서 생성하지 않아도 된다는 장점과 너무 많은 쓰레드가 생성되어 성능이 저하되는 것을 막아준다는 장점이 있다.

 

쓰레드 풀은 쓰레드가 수행해야하는 작업이 담긴 큐를 제공하며, 각 쓰레드는 자신의 작업 큐에 담긴 작업을 순서대로 처리한다.

 

compute()구현

compute()를 구현할 때는 수행할 작업 외에도 작업을 어떻게 나눌 것인지도 알려줘야한다.

 

public Long compute(){
	long size = to - from +1; // from <= i <= to
    
    if(size <= 5){ // 더할 숫자가 5개 이하면
    	return sum(); // 숫자의 합을 반환. sum()은 from부터 to까지의 수를 더해서 반환
    }
    
    // 범위를 반으로 나눠서 두 개의 작업을 생성
    long half = (from + to) / 2;
    
    SumTask leftSum = new SumTask(from, half);
    SumTask rightSum = new SumTask(half+1, to);
    
    leftSum.fork(); // 작업(leftSum)을 작업 큐에 넣는다.
    
    return rightSum.compute() + leftSum.join();
}

 

여기서는 지정된 범위를 절반으로 나누어서 나눠진 범위의 합을 계산하기 위한 새로운 SumTask를 생성하는데,

이 과정은 작업이 더 이상 나눠질 수 없을 때까지,

size의 값이 5보다 작거나 같을 때까지 반복된다.

 

1부터 8까지의 숫자를 더하는 과정의 그림.

 

다른 쓰레드의 작업 훔쳐오기

fork()가 호출되어 작업 큐에 추가된 작업 역시, compute()에 의해 더 이상 나눌 수 없을때까지 반복해서 나뉘고,

자신의 작업 큐가 비어있는 쓰레드는 다른 쓰레드의 작업 큐에서 작업을 가져와서 수행한다.

이것을 작업 훔쳐오기(work stealing)라고 하며, 이 과정은 모두 쓰레드풀에 의해 자동적으로 이루어진다.

이 과정을 통해, 여러 쓰레드가 골고루 작업을 나누어 처리하게 된다.

 

fork() 와 join()

fork() : 해당 작업을 쓰레드 풀의 작업 큐에 넣는다. 비동기 메서드
join() : 해당 작업의 수행이 끝날 때까지 기다렸다가, 수행이 끝나면 그 결과를 반환한다. 동기 메서드

 

비동기 메서드는 일반적인 메서드와 달리 메서드를 호출만 할 뿐, 그 결과를 기다리지 않는다.

 

작업을 나누고 합치는데 걸리는 시간이 있기 때문에 무조건 멀티쓰레드로 처리하는 것이 좋은게 아니다.

테스트해보고 이득이 있을 때만 멀티쓰레드로 처리하자.

 

예시

import java.util.concurrent.*;

class ForkJoinEx1 {
	static final ForkJoinPool pool = new ForkJoinPool();  // 쓰레드풀을 생성

	public static void main(String[] args) {
		long from = 1L;
		long to   = 100_000_000L;

		SumTask task = new SumTask(from, to);

		long start = System.currentTimeMillis(); // 시작시간 초기화
		Long result = pool.invoke(task);

		System.out.println("Elapsed time(4 Core):"+(System.currentTimeMillis()-start));
		System.out.printf("sum of %d~%d=%d%n", from, to, result);
		System.out.println();

		result = 0L;
		start = System.currentTimeMillis(); // 시작시간 초기화
		for(long i=from;i<=to;i++)
			result += i;

		System.out.println("Elapsed time(1 Core):"+(System.currentTimeMillis()-start));
		System.out.printf("sum of %d~%d=%d%n", from, to, result);
	} // main의 끝
}

class SumTask extends RecursiveTask<Long> {
	long from;
	long to;

	SumTask(long from, long to) {
		this.from = from;
		this.to    = to;
	}

	public Long compute() {
		long size = to - from;

		if(size <= 5)     // 더할 숫자가 5개 이하면
			return sum(); // 숫자의 합을 반환

		long half = (from+to)/2;

		// 범위를 반으로 나눠서 두 개의 작업을 생성
		SumTask leftSum  = new SumTask(from, half);
		SumTask rightSum = new SumTask(half+1, to);

		leftSum.fork();

		return rightSum.compute() + leftSum.join();
	}

	long sum() { // from~to의 모든 숫자를 더한 결과를 반환
		long tmp = 0L; 

		for(long i=from;i<=to;i++)
			tmp += i;

		return tmp;
	}
}
//결과
Elapsed time(4 Core):1407
sum of 1~100000000=5000000050000000

Elapsed time(1 Core):1266
sum of 1~100000000=5000000050000000

 

참조

'Java의 정석' 책

댓글