본문 바로가기
알고리즘/이론

[알고리즘] 병합 정렬 (Merge Sort)

by 쿠브 2022. 5. 2.

병합 정렬이란

배열을 앞부분과 뒷부분 두 그룹으로 나누어 각각 정렬한 후, 다시 병합하는 작업을 반복하는 알고리즘입니다. 분할 정복(divide and conquer) 방법을 사용하여 문제를 해결합니다.

 

병합 정렬 예

[5, 3, 2, 1, 6, 8, 7, 4]     #정렬하려는 배열

 

1. divide

1단계: [5] [3]을 병합하면 [3, 5]

             [2] [1]을 병합하면 [1, 2]

             [6] [8]을 병합하면 [6, 8]

             [7] [4]을 병합하면 [4, 7]

2단계: [3, 5] 과 [1, 2]을 병합하면 [1, 2, 3, 5]

             [6, 8] 와 [4, 7]을 병합하면 [4, 6, 7, 8]

 

A : [1, 2, 3, 5]     #정렬된 배열 

B : [4, 6, 7, 8]     #정렬된 배열 

C : []                #정렬할 원소를 저장할 빈 배열

 

2. conquer

1단계:  [1, 2, 3, 5]

              [4, 6, 7, 8]    #1 < 4 이므로 1을 C 에 넣습니다.

              [1]

2단계 : [1, 2, 3, 5]

              [4, 6, 7, 8]    #2 < 4 이므로 2를 C 에 넣습니다.

              [1, 2]

3단계 : [1, 2, 3, 5]

              [4, 6, 7, 8]    #3 < 4 이므로 3을 C 에 넣습니다.

              [1, 2, 3]

3단계 : [1, 2, 3, 5]

              [4, 6, 7, 8]    #5 > 4 이므로 4을 C 에 넣습니다.

              [1, 2, 3, 4]

3단계 : [1, 2, 3, 5]

              [4, 6, 7, 8]    #5 < 6 이므로 5을 C 에 넣습니다.

              [1, 2, 3, 4, 5]

 

A의 모든 원소는 C로 옮겼습니다.

B에 남은 원소 [6, 7, 8]은 어떡할까요?

하나씩 C에 추가하면 됩니다.

C : [1, 2, 3, 4, 5, 6, 7, 8]

 

병합 정렬 python 코드

array = [5, 3, 2, 1, 6, 8, 7, 4]


def merge_sort(array):
    if len(array) <= 1:
        return array
    mid = len(array) // 2
    left_array = array[:mid]
    right_array = array[mid:]
    print(array)
    print('left_arary', left_array)
    print('right_arary', right_array)
    return merge(merge_sort(left_array), merge_sort(right_array))


def merge(array1, array2):
    result = []
    array1_index = 0
    array2_index = 0
    while array1_index < len(array1) and array2_index < len(array2):
        if array1[array1_index] < array2[array2_index]:
            result.append(array1[array1_index])
            array1_index += 1
        else:
            result.append(array2[array2_index])
            array2_index += 1

    if array1_index == len(array1):
        while array2_index < len(array2):
            result.append(array2[array2_index])
            array2_index += 1

    if array2_index == len(array2):
        while array1_index < len(array1):
            result.append(array1[array1_index])
            array1_index += 1

    return result


print(merge_sort(array))

print("정답 = [-7, -1, 5, 6, 9, 10, 11, 40] / 현재 풀이 값 = ", merge_sort([-7, -1, 9, 40, 5, 6, 10, 11]))
print("정답 = [-1, 2, 3, 5, 10, 40, 78, 100] / 현재 풀이 값 = ", merge_sort([-1, 2, 3, 5, 40, 10, 78, 100]))
print("정답 = [-1, -1, 0, 1, 6, 9, 10] / 현재 풀이 값 = ", merge_sort([-1, -1, 0, 1, 6, 9, 10]))

 

병합 정렬의 시간 복잡도

모든 단계에서 N만큼 비교를 합니다.

크기가 N → N/2 → N/2^2 → N/2^3 → .... → 1이 되는 순간이 올텐데  

log_2N 번 반복하게 되면 1이 됩니다. (k=log_2N)

이걸 수식으로 나타내면 N만큼의 연산을 logN번 반복한다고 해서

시간 복잡도는 O(Nlog_2N) = O(NlogN)이 됩니다.

댓글