알고리즘/알고리즘(Java)

[분할정보]Counting Inversions

산을좋아한라쯔 2015. 10. 27. 13:37
반응형

문제

배열 A는 1,2,3,...n 의 수가 무작위 순서로 들어 있다. 이 수들에서 두개의 무작위 수를 생각했을 때, 그 순서 대비 크기가 역전되어 있는 두 수의 쌍이 몇개가 되는 지를 구하시오. 

 Number of inversions = Number of pairs(i,j)  when i<j and A[i] > A[j]


예를 들어, A={2,3,6,4,1,7}일 때, 크기가 역전된 쌍은, (2,1), (3,1), (6,4), (6,1), (4,1)

따라서 Number of inversions = 5


제약조건

각 문제에 대해 답을 찾아내는 시간은 0.01초(10ms 이내)일 것 

--> O(n2)이 되게 알고리즘 짜지 말라는 것.


입력

입력문제는 텍스트파일로 주어지고, 파일내 형식은 다음과 같음

 - 첫번 째 행에, 전체 문제 갯수가 나오고

 - 두번째 행부터 다음과 같은 형식

   . 배열의 수가 나오고

   . 그 다음 줄에 공백으로 구분된 배열값 나옴


예)

4

8

6 5 3 1 8 7 2 4 

6

2 3 6 4 1 7

7

2 3 6 4 1 7 0

3000

222 464 2405 2435 ...


----------------------------------------------------------------------------------

*******************************************************************************************************

----------------------------------------------------------------------------------

풀이

직관적으로 풀면 대략 O(n2)되는 알고리즘이 나와서, 입력 배열이 수천개를 넘어서면, 제약조건인 10ms를 못 맞춘다.

(물론 엄청 빠른 CPU를 쓰면 가능할 지도...^^)

  - 첫번 째 수를 선택하고, 나머지에서 첫번째 수보다 작은 값 갯수 구한다. -> (n-1)회 비교

  - 두번 째 수를 선택하고, 나머지 (n-2)개에서 작은 값을 구한다. -> (n-2)회 비교

  - ...

  

수 천개의 입력값에 대해서도 10ms 이하 정도의 수행속도를 보이려면, nlogn 정도의 알고리즘이 필요하다.

어떻게 구할까?


앞에서 봤던, 분할정복에 의한 MergeSort를 응용하면 된다. 

(앞 장의 MergeSort를 안 봤다면, 이 부분을 먼저 봐야 이해가 된다.)


원리는 이렇다.

  - 만약 두 개의 소팅된 블럭이 있다면, 두 블럭사이의 Inversion Pair의 갯수는 O(n)에 구할 수 있다.

아래 그림으로 생각해 보자.




두번 째 그림을 보면, 오른편에 있는 '2'가 선택이 되었다. 즉, 왼쪽편 블럭에 있는 '3'보다 작기 때문이다. 그렇다면, 왼쪽 블럭에 있는 '3'과 그 다음에 있는 값들은 전부 '2'보다 크다는 얘기가 된다. (소팅되어 있으니깐)

따라서, MergeSort의 merge() 과정에서, 오른편 값이 선택될 때, 왼쪽편 블럭의 인덱스인 i 이후에 얼마나 값이 있는가가, 해당 오른편 값에 대한 Inversion 쌍 수량이 된다.

이러한 과정을 merge가 종료될 때까지 하게되면, 두 블럭간의 Inversion Count를 구할 수 있게 된다. (블럭내에서의 Inversion은 재귀호출을 통해서.)


이제 정리해보면 다음과 같다.

 - leftCount = count(왼쪽편 블럭)

 - rightCount = count(오른편 블럭)

 - mergeCount = merge(왼쪽편, 오른편)

TotalCount = leftCount + rightCount + mergeCount



소스

package divideandconquer;

 

import java.io.FileInputStream;

import java.util.Scanner;

        

public class CountingInversions { 

        

         public static void main(String[] args)throws Exception {

                  int T, N;

                  int answer;

 

                  System.setIn(new FileInputStream("res/CountingInversions.in"));

 

                  Scanner sc = new Scanner(System.in);

                 

                  long t1,t2;

                  T = sc.nextInt();

                  for (int test_case = 1; test_case <= T; test_case++) {

                           N = sc.nextInt();

                           int[] a = new int[N];

                           for(int i=0;i<N;i++){

                                   a[i] = sc.nextInt();

                           }

                          

                           t1 = System.currentTimeMillis();

                           answer = countingInversions(a);

                           t2 = System.currentTimeMillis();

                           System.out.println("TestCase # "+test_case+":"+answer +

" ("+(t2-t1)+" milliseconds)");

                  }

         }

 

         private static int countingInversions(int[] a) {

                  int n = a.length;

                  int[] buf = new int[n];           

                  int cnt = count(a,0,n-1,buf);

                 

                  return cnt;

         }

        

         private static int count(int[] a, int s, int e, int[] buf){

                  if((e-s)<1) {

                           return 0;

                  }       

                 

                  int m = (s+e) / 2;

                  int leftCount = count(a,s,m,buf);

                  int rightCount = count(a,m+1,e,buf);

                  int mergeCount = merge(a,s,m,e,buf);

                  System.arraycopy(buf, s, a, s, (e-s)+1);

                 

                  return leftCount + rightCount + mergeCount;

         }

        

         private static int merge(int[]a,int s, int m, int e, int[] buf){

                  //System.out.println("merge("+s+" "+e+")");

                  int left=s;

                  int right=m+1;

                  int count=0;

                  for(int k=s;k<=e;k++){

                           if(left<=m && ( (right>e) || (a[left] <= a[right]) ) ){

                                   buf[k] = a[left++];                                 

                           }else{

                                   buf[k] = a[right++];

                                   count = count + (m-left+1);

                           }

                  }

                  return count;

         }

        

        

}

 

-끝-


반응형