알고리즘/알고리즘(Java)

[분할정복]Closest Points - 가장 가까운 점 사잇거리 구하기

산을좋아한라쯔 2015. 10. 28. 14:10
반응형

문제

이차원 평면상에 위치한 많은 점들이 있다. 두 점 P(x1,y1) Q(x2,y2)의 사이의 거리는 SQRT[(x1-x2)^2 + (y1-y2)^2]로 정의한다면,

모든 점들 사이에서, 가장 가까운 거리를 찾아내시오.


제약조건

각 문제에 대해 답을 찾아내는 시간은 0.5초(500ms 이내)일 것


입력

입력문제는 텍스트파일로 주어지고, 파일내 형식은 다음과 같음

 - 첫번 째 행에, 전체 문제 갯수가 나오고

 - 두번째 행부터 다음과 같은 형식

   . 점의 수가 나오고

   . 그 다음 줄에 공백으로 구분되어 x,y값이 나옴


예)

3

12

2 7 4 13 5 8 10 5 14 9 15 5 17 7 19 10 22 7 25 10 29 14 30 2 

3000

2248 2935 2815 3 315 ...

100000

55722 33770 3401 10325...


문제 소스

import java.io.FileInputStream;

import java.util.Scanner;

 

public class Test_ClosestDistance {

         public static void main(String[] args) throws Exception {

                  int T, N;

                  double answer;

 

                  System.setIn(new FileInputStream("res/ClosestDistance.in"));

 

                  Scanner sc = new Scanner(System.in);

 

                  long t1, t2;

                  T = sc.nextInt();

                  for (int test_case = 1; test_case <= T; test_case++) {

                           N = sc.nextInt();

                           int[][] p = new int[N][2];

                           for (int i = 0; i < N; i++) {

                                   p[i][0] = sc.nextInt();

                                   p[i][1] = sc.nextInt();

                           }

 

                           t1 = System.currentTimeMillis();

                           answer = closestDistance(p);

                           t2 = System.currentTimeMillis();

                           System.out.println("TestCase # " + test_case + ":" +

answer + " (" + (t2 - t1) + " milliseconds)");      

                  }

                  sc.close();

         }                

 

         private static double closestDistance(int[][] p) {

                 

                  return 0;

         }

}


-------------------------------------------------------------

****************************************************************************

-------------------------------------------------------------


풀이

이 문제는 'Closest Set of Points'라고 불리는, 분할정복 알고리즘에 있어서 유명한 문제이고, 분할정복관련 문제 중 어려운 문제에 속한다.


n개의 점이 주어졌을 때, 두 점 사잇거리의 경우의 수는 n2/2 

    첫번 째 점이 선택할 수 있는 나머지 점들 갯수: n-1

    두번 째 점이 선택할 수 있는 나머지 점들 갯수 : n-2

    ..

    n-1번째 점이 선택할 수 있는 나머지 점들 갯수: 1

   ------------------------------------------

                                                     총 n2/2 개


따라서, 그냥 직관적으로 풀면 O(n2)내에 풀 수 있다.

                  double min=Double.POSITIVE_INFINITY;

                  double d;

                  for(int i=0;i<n;i++){

                           for(int j=i+1;j<n;j++){

                                   d = dist(points[i],points[j]);                               

                                   if(d<min)min=d;

                           }

                  }

                  System.out.println("closest distance="+min);



근데 문제는, n의 갯수가 많아지면 주어진 시간 안에 답이 안나온다는 것.


분할정복으로 풀면 O(nlog(n))에 문제를 풀 수 있다.

(아래 설명할 것은 사실 O(n log(n)2) 이다. 그러나, 현실적으로는 O(nlog(n))이 되게 구현한 것보다 빠르다. 

  nlog(n)이 되게 하기위해, y축으로 소팅된 배열을 관리하는 오버헤드가 너무 커서 그런 듯하다.)


기본 아이디어는,

  - x축을 기준으로해서, 분할을 recursive하게 하고,

  - 왼쪽편에서 가장 작은 거리 d1, 오른편에서 가장 작은 거리 d2, 양쪽 사이에서 가장 작은 거리 d3중에서,

    가장 작은 거리를 min distance로 구함


d1과 d2는 d3가 구해지면 재귀적으로 자동 구해지겠는데, 문제는 d3를 구하는 것을 O(n)타임에 하는 것이, 알고리즘의 핵심이다.

그냥 직관적으로 구해버리면, 왼쪽편과 오른편에 있는 것들간의 모든 조합이 되어버려서 O(n2)이 되어 버린다.

키 아이디어는, 이미 구해져 있는 d1과 d2를 이용해서, 비교하는 갯수를 획기적으로 줄여버리는 것.

아래 그림을 보자.


x축을 기준으로 해서 봤을 때, 현재 start와 end사이에 점들 있고, middle을 기준으로해서 왼쪽편에서 가장 작은 사잇거리가 d1이고, 오른편 제일 작은 사잇거리는 d2이다.  (d1과 d2를 어떻게 구해지는 지는 일단 생각을 말자. 이미 구해졌다고 가정하자. 사실은, 지금 과정에서 구하려는 min(d1,d2,d3)를 통해서 재귀적으로 d1과 d2가 전과정(before)에서 재귀적으로 구해진 것)

이 상태에서, 왼쪽편만의, 그리고 오른편만에서의 가장 작은 사잇거리는 알겠는데, 이제 왼편에 있는 점들과 오른편에 있는 점들과의 사잇거리중에서 d1,d2보다 작은 것이 있는 지를 구해야 한다.


아래 그림처럼, d범위안에서만 확인하면 된다. d는 d1과 d2중 작은 값이다. d=min(d1,d2)


사실 당연한 거다. 가장 작은 사잇거리를 찾는 것이고, 이미 왼쪽편과 오른편을 통틀어 가장 작은 값인 d가 구해졌기에, 왼편과 오른편 점들과의 연결 중에서 d보다도 더 작은 것이 있는 가를 찾는 것이기에, m을 기준으로해서 왼편으로 d만큼, 오른편으로 d만큼 거리 안에서만 찾으면 된다.


 bPoints = 축을 기준으로 (m-d)~(m+d) 사이에 위치한 점들 (band Points)


이 bPoints가 y로 정렬되어 있다면, 이 bPoints에 있는 점들 모두를 비교할 필요 없다. 즉, 더 비교횟수를 줄일 수 있다.

즉, bPoints의 가장 밑에 있는 점에서 출발해서, y축 사잇거리가 d보다 작은 것들끼리만 비교하면 된다. 


위 그림처럼, 현재 점을 기준으로 해서, y축 기준 사잇거리가 d이내의 점까지만 비교하고, 넘어가면 루프를 벗어나서 다음 점으로 이동하면 된다.

                  double d;

                  for(int i=0;i<size;i++){

                           //for(int j=i+1;j<size && j<=(i+6);j++){

                           for(int j=i+1;

j<size && ( (bPoints[j].y-bPoints[i].y) <minDist );j++){

                                   d = dist(bPoints[i], bPoints[j]);

                                   if(d<minDist) {

                                            minDist = d;                                        

                                   }

                           }

                  }



이제, 처음부터 Full로 코드를 짜보자.

가장 상위에 있는 메서드로 closestDistance()를 만들자. 이 메서드는, x축과 y축 값으로 구성된 2차원 배열을 입력받아서, 가장 짧은 사잇거리를 출력해내는 메서드.

   private static double closestDistance(int[][] p)


이 메서드에선 크게 3가지 작업 실시.

1. 이차원 배열을 Point객체 배열로 전환: p -> xPoints

   (그냥 2차원 배열을 사용해도 되지만, 계산과 코드를 깔끔하게 하기 위해 Point라는 클래스 작성해서 사용하자)

2. xPoints를 x축 기준으로 소팅

3. 재귀적으로 호출하면서 최소사잇거리를 계산할 closest() 메서드 호출


         private static double closestDistance(int[][] p) {

                  if (p == null) {

                           return -1;

                  }

 

                  int n = p.length;

                  if (n == 1)

                           return -1;

 

                  // 1. Generate Points with given array: O(n)

                  Point[] points = new Point[n];

                  for (int i = 0; i < n; i++) {

                           points[i] = new Point(i,p[i][0], p[i][1]);

                  }

                 

                  // 2. Sort by x-coordinate: O(nlog(n))

                  Point[] xPoints = new Point[n];

                  System.arraycopy(points, 0, xPoints, 0, n);

                  Arrays.sort(xPoints, new Comparator<Point>() {

                           @Override

                           public int compare(Point p1, Point p2) {

                                   if (p1.x > p2.x) {

                                            return 1;

                                   } else if (p1.x < p2.x) {

                                            return -1;

                                   }

                                   return 0;

                           }

                  });

 

                  // 3. calculate closest distance: nlog(n)^2         

                  double d = closest(xPoints, 0, n - 1);

                 

                  return d;

         }


이제, closest() 메서드를 보자.

private static double closest(Point[] xPoints, int s, int e


이 메서드는 전형적인 '분할정복' 알고리즘 처럼, 다음과 같은 순서로 처리.

  1)Base: 가장 작은 단위일 때 벗어나는 조건 처리

  2)Divide: 재귀적으로 작은 단위로 쪼겜 (Left, Right)

  3)Merge: O(n) 실행시간 내에서, 문제 처리


         private static double closest(Point[] xPoints, int s, int e) {

                  // 1. base

                  if ((e - s) == 0) {

                           return Double.POSITIVE_INFINITY;

                  }

 

                  // 2. divide : log(n)

                  int m = (s + e) / 2;

 

                  double d1 = closest(xPointss, m);

                  double d2 = closest(xPointsm + 1, e);

                  double d = (d1 < d2) ? d1 : d2;

 

                  // 3. merge

                  // (m-d) ~ (m+d) 사이 band 있는 점들 = bPoints

                  Point[] bPoints = new Point[e - s + 1];

                  Point midPoint = xPoints[m];

                  int k=0;

                  for (int i = s; i <=e ; i++) {

                           if (Math.abs(xPoints[i].x - midPoint.x) <= d) {                                

                                   bPoints[k++] = xPoints[i];

                           }

                  }

 

                  double d3 = closestInBand(bPoints, k, d );

 

                  return d3;

         }


1. base 

xPoints에서 처리할 점의 갯수가 1일 때, 굉장히 큰 값을 리턴하게 되어 있다. 즉, 재귀적으로 호출되기에 log(n)의 속도로 잘게 쪼게지다가, 결국 처리할 점의 갯수가 1개일 때 굉장히 큰 값(Double.POSITIVE_INFINITY) 리턴. 

--> 이러한 점이 1개짜리인 블럭과, 또 다른 점이 1개짜리인 블럭끼리 재귀적으로 호출될 것이고, 이 때 d1=POSITIVE_INFINITY, d2=POSITIVE_INFINITY가 되고, 따라서, 각 블럭에 있는 점 사잇거리인 d3가 최소값이 될 것이다. (요 부분은 코드 아랫부분인 3.merge에서 계산되는 것임)


2. Divide

'분할정복'의 전형이다. x축을 기준으로 절반씩 나눠서 호출된다.


3. Merge

이 부분이 알고리즘의 핵심이고, 이해하기 힘든 부분이다.

이 글의 윗 부분에 그림으로 설명했듯이, x축의 m값을 기준으로 (m-d)~(m+d)값에 대해서만 조사를 하면 되기에, 이 부분에 있는 점들을 모아 놓은 bPoints 배열을 만든다. 이 부분에 있는 점들이 몇 개나 될지는 아직 모르기에 최댓값으로 잡아놓고, (최댓값= e-s+1 )

루프를 돌면서 크기를 알아낸다. (k)

                  Point[] bPoints = new Point[e - s + 1];

                  Point midPoint = xPoints[m];

                  int k=0;

                  for (int i = si <=e ; i++) {

                           if (Math.abs(xPoints[i].x - midPoint.x) <= d) {                                

                                   bPoints[k++] = xPoints[i];

                           }

                  }


이제, 이 밴드내에서 d보다 작은 사잇거리가 있는지 찾아내는 작업을 하는, closestInBand() 메서드를 만들자.

  private static double closestInBand(Point[] bPoints, int size, double minDist)


이 메서드는, 먼저 bPoints를 y축을 준으로 소팅하고, 가장 밑에 있는 점부터 시작해서 위쪽으로 가면서, 가장 사잇거리가 짧은 점을 찾아낸다.

여기서 핵심은, 두 점간의 y축 거리가 d를 넘어서는 경우, 다음 점으로 이동해 버린게 한다는 것. y축 거리가 d보다 크면서, 사잇거리가 d보다 작은 점은 있을 수 없기 때문. 

  for(int j=i+1;j<size && ( (bPoints[j].y-bPoints[i].y) <minDist );j++)


(이 문제에 대한 풀이를 하는 알고리즘 강의를 보면, 6개 정도만 비교하면 된다고 한다. 그것은, 실제 최악의 경우도 이 정도 밖에 비교안한다는 것이고, 실제로 6개를 비교하게 프로그래밍 할 필요는 없을 것이다. y 길이가 d보다 커지게 되면 바로 벗어나 버리는 것이 현명.) 

         private static double closestInBand(Point[] bPoints, int size, double minDist) {

                  //sort

                  Arrays.sort(bPoints, 0, size, new YComparator());

                 

                  double d;

                  for(int i=0;i<size;i++){

                           //for(int j=i+1;j<size && j<=(i+6);j++){

                           for(int j=i+1;j<size && ( (bPoints[j].y-bPoints[i].y) <minDist );j++){

                                   d = dist(bPoints[i], bPoints[j]);

                                   if(d<minDist) {

                                            minDist = d;                                        

                                   }

                           }

                  }

                 

                  return minDist;

         }


입력으로 주언진 파일에 대한 답은 다음과 같다.

TestCase # 1:2.8284271247461903 (2 milliseconds)

TestCase # 2:1.4142135623730951 (35 milliseconds)

TestCase # 3:1.0 (367 milliseconds)


소스

package divideandconquer;

 

import java.io.FileInputStream;

import java.util.Arrays;

import java.util.Comparator;

import java.util.Scanner;

 

public class ClosestDistance {    

         public static void main(String[] args) throws Exception {

                  int T, N;

                  double answer;

 

                  System.setIn(new FileInputStream("res/ClosestDistance.in"));

 

                  Scanner sc = new Scanner(System.in);

 

                  long t1, t2;

                  T = sc.nextInt();

                  for (int test_case = 1; test_case <= T; test_case++) {

                           N = sc.nextInt();

                           int[][] p = new int[N][2];

                           for (int i = 0; i < N; i++) {

                                   p[i][0] = sc.nextInt();

                                   p[i][1] = sc.nextInt();

                           }

 

                           t1 = System.currentTimeMillis();

                           answer = closestDistance(p);

                           t2 = System.currentTimeMillis();

                          System.out.println("TestCase # " + test_case + ":" + answer + " (" + (t2 - t1) + " milliseconds)");

                          

                           /*

                           System.out.println("");

                           t1 = System.currentTimeMillis();

                           printAllDistance(p);

                           t2 = System.currentTimeMillis();

                          System.out.println("TestCase # " + test_case + ":" + answer + " (" + (t2 - t1) + " milliseconds)");

                           */

                  }

                  sc.close();

         }

        

         private static void printAllDistance(int[][] p){

                  int n = p.length;

                 

                  Point[] points = new Point[n];

                  for (int i = 0; i < n; i++) {

                           points[i] = new Point(i,p[i][0], p[i][1]);

                  }

                 

                  double min=Double.POSITIVE_INFINITY;

                  double d;

                  for(int i=0;i<n;i++){

                           for(int j=i+1;j<n;j++){

                                   d = dist(points[i],points[j]);                               

                                   if(d<min)min=d;

                           }

                  }

                  System.out.println("closest distance="+min);

         }

 

         private static double closestDistance(int[][] p) {

                  if (p == null) {

                           return -1;

                  }

 

                  int n = p.length;

                  if (n == 1)

                           return -1;

 

                  // 1. Generate Points with given array: O(n)

                  Point[] points = new Point[n];

                  for (int i = 0; i < n; i++) {

                           points[i] = new Point(i,p[i][0], p[i][1]);

                  }

                 

                  // 2. Sort by x-coordinate: O(nlog(n))

                  Point[] xPoints = new Point[n];

                  System.arraycopy(points, 0, xPoints, 0, n);

                  Arrays.sort(xPoints, new Comparator<Point>() {

                           @Override

                           public int compare(Point p1, Point p2) {

                                   if (p1.x > p2.x) {

                                            return 1;

                                   } else if (p1.x < p2.x) {

                                            return -1;

                                   }

                                   return 0;

                           }

                  });

 

                  // 3. calculate closest distance: nlog(n)^2         

                  double d = closest(xPoints, 0, n - 1);

                 

                  return d;

         }

 

         private static double closest(Point[] xPoints, int s, int e) {

                  // 1. base

                  if ((e - s) == 0) {

                           return Double.POSITIVE_INFINITY;

                  }

 

                  // 2. divide : log(n)

                  int m = (s + e) / 2;

 

                  double d1 = closest(xPointss, m);

                  double d2 = closest(xPointsm + 1, e);

                  double d = (d1 < d2) ? d1 : d2;

 

                  // 3. merge

                  // (m-d) ~ (m+d) 사이 band 있는 점들 = bPoints

                  Point[] bPoints = new Point[e - s + 1];

                  Point midPoint = xPoints[m];

                  int k=0;

                  for (int i = s; i <=e ; i++) {

                           if (Math.abs(xPoints[i].x - midPoint.x) <= d) {                                

                                   bPoints[k++] = xPoints[i];

                           }

                  }

 

                  double d3 = closestInBand(bPoints, k, d );

 

                  return d3;

         }

 

         private static double closestInBand(Point[] bPoints, int size, double minDist) {

                  //sort

                  Arrays.sort(bPoints, 0, size, new YComparator());

                 

                  double d;

                  for(int i=0;i<size;i++){

                           //for(int j=i+1;j<size && j<=(i+6);j++){

                           for(int j=i+1;j<size && ( (bPoints[j].y-bPoints[i].y) <minDist );j++){

                                   d = dist(bPoints[i], bPoints[j]);

                                   if(d<minDist) {

                                            minDist = d;                                        

                                   }

                           }

                  }

                 

                  return minDist;

         }

 

 

        

         private static double dist(Point a, Point b){

                  return a.dist(b);

         }

        

         static class Point {

                  int idx;

                  int x, y;

                 

                  public Point(int idx,int x, int y) {

                           this.idx=idx;

                           this.x = x;

                           this.y = y;

                  }

 

                  public double dist(Point b) {

                           // SQRT[(a.x - b.x)^2 + (a.y - b.y)^2]

                           return Math.sqrt(Math.pow((x - b.x), 2) + Math.pow(y - b.y, 2));

                          

                  }

         }

        

         static class YComparator implements Comparator<Point>{

                  @Override

                  public int compare(Point p1, Point p2) {

                           if (p1.y > p2.y) {

                                   return 1;

                           } else if (p1.y < p2.y) {

                                   return -1;

                           }

                           return 0;

                  }

         }

 

}





-끝- 



반응형