13537번: 수열과 쿼리 1
문제 링크: 13537번: 수열과 쿼리 1
개요
배열 크기 n과 쿼리의 수 m이 최대 100,000이므로 각 구간에서 k보다 큰 값을 찾을 때마다 구간을 전부 탐색하면 시간복잡도가 O(n×m)이 되어 시간 초과와 맞닥뜨리게 될 것이다.
세그먼트 트리의 각 노드에 자신과 자식 노드의 모든 원소를 정렬한 vector를 저장해서 풀 수 있다. 세그먼트 트리에 대해 잘 모른다면 아래 문제 풀이를 먼저 읽어보자.
[백준 14438번] 수열과 쿼리 17 [C++]
풀이
아래 설명할 두 함수는 node, start, end라는 인자를 갖고 있다. node는 세그먼트 트리상에서 현재 노드의 인덱스이며, start와 end는 arr 배열 상에서의 현재 탐색 범위의 양 끝 인덱스이다. 두 함수는 재귀의 형태를 가지는데, 위에서 말한 세 인자는 재귀를 거치며 segTree[node]가 arr[start]부터 arr[end]까지의 범위를 담당하는 노드임이 유지되도록 변화한다.
세그먼트 트리 생성 함수
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// 정렬된 자식노드 vector 세그먼트 트리 만들기
vector<int> makeSegTree(int node, int start, int end) {
if (start == end) {
segTree[node].push_back(arr[start]);
return segTree[node];
}
int mid = (start + end) / 2;
vector<int> leftRst = makeSegTree(node * 2, start, mid);
vector<int> rightRst = makeSegTree(node * 2 + 1, mid + 1, end);
segTree[node] = leftRst;
segTree[node].insert(segTree[node].end(), rightRst.begin(), rightRst.end());
sort(segTree[node].begin(), segTree[node].end());
return segTree[node];
}
보통 세그먼트 트리와는 다르게, 이 문제에서는 세그먼트 트리의 원소의 자료형이 정수 하나가 아니라 vector이다. main 함수에서 이 함수를 makeSegTree(1, 1, n)의 형태로 호출할 것이다. 세그먼트 트리의 루트 노드(1, 전체 범위 담당)부터 시작하여 재귀의 형태로 자식 노드로 뻗어 내려가고, 자식 노드의 값을 이용해 부모 노드의 값을 생성하는 방식이다. 초기 상태의 세그먼트 트리의 모든 노드(vector)는 비어있다.
start == end인 경우, 리프 노드에 도달했다는 뜻이므로 현재 세그먼트 트리의 노드(segTree[node])에 arr 배열의 값(arr[start])을 push하고 재귀를 종료한다.
그 외의 경우, 아직 리프 노드에 도달하지 않았다. 세그먼트 트리에서 인덱스가 node인 노드의 자식 노드의 인덱스는 node * 2, node * 2 + 1이다. 현재 범위를 반으로 나눠 왼쪽 범위의 vector와 오른쪽 범위의 vector를 재귀를 통해 구한 후 현재 노드에 두 자식 노드 vector을 합쳐서 정렬한 vector를 만들어주고 반환한다.
탐색 범위의 k보다 큰 원소를 반환하는 함수
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
//탐색 범위의 k보다 큰 원소 개수 얻기
int getCnt(int node, int start, int end, int left, int right, int k) {
//탐색 범위가 얻을 범위에 포함되지 않는 경우
if (left > end || right < start) return 0;
//탐색 범위가 얻을 범위에 완전히 포함되는 경우
if (left <= start && end <= right) {
return segTree[node].end() - upper_bound(segTree[node].begin(), segTree[node].end(), k);
}
//탐색 범위가 얻을 범위에 걸치는 경우
int mid = (start + end) / 2;
int leftRst = getCnt(node * 2, start, mid, left, right, k);
int rightRst = getCnt(node * 2 + 1, mid + 1, end, left, right, k);
return leftRst + rightRst;
}
left와 right는 k보다 큰 원소의 개수를 얻을 범위이며, 재귀를 거치며 변하지 않는다.
탐색 범위(start ~ end)가 얻을 범위(left ~ right)에 포함되지 않는 경우, 이 부분의 반환값은 최종 결과에 영향을 미쳐서는 안 되므로 0을 반환한다.
탐색 범위(start ~ end)가 얻을 범위(left ~ right)에 완전히 포함되는 경우, 세그먼트 트리에서 해당 노드 k보다 큰 원소의 개수를 반환한다. upper_bound를 이용해서 k보다 큰 가장 작은 원소가 처음 등장하는 위치를 세그먼트 트리의 마지막 원소의 다음 위치에서 뺀 값이 구하고자 하는 값이다.
그 외의 경우는 탐색 범위(start ~ end)가 얻을 범위(left ~ right)에 걸치는 경우이다. 위의 두 함수와 마찬가지로 현재 범위를 반으로 나눠 각각 탐색 후 두 반환값의 합을 반환한다.
전체 코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <iostream>
#include <vector>
#include <algorithm>
#pragma warning(disable:4996)
#pragma warning(disable:6031)
using namespace std;
#define MAX_N 100000
int arr[MAX_N + 1];
vector<int> segTree[MAX_N * 4 + 4];
vector<int> makeSegTree(int node, int start, int end);
int getCnt(int node, int start, int end, int left, int right, int k);
int main() {
int n, m;
scanf("%d", &n);
for (int i = 1; i <= n; i++) { //수열 입력받기
scanf("%d", &arr[i]);
}
makeSegTree(1, 1, n); //세그먼트 트리 만들기
scanf("%d", &m);
for (int t = 0; t < m; t++) { //쿼리 입력받기
int i, j, k;
scanf("%d %d %d", &i, &j, &k);
printf("%d\n", getCnt(1, 1, n, i, j, k));
}
return 0;
}
// 정렬된 자식노드 vector 세그먼트 트리 만들기
vector<int> makeSegTree(int node, int start, int end) {
if (start == end) {
segTree[node].push_back(arr[start]);
return segTree[node];
}
int mid = (start + end) / 2;
vector<int> leftRst = makeSegTree(node * 2, start, mid);
vector<int> rightRst = makeSegTree(node * 2 + 1, mid + 1, end);
segTree[node] = leftRst;
segTree[node].insert(segTree[node].end(), rightRst.begin(), rightRst.end());
sort(segTree[node].begin(), segTree[node].end());
return segTree[node];
}
//탐색 범위의 k보다 큰 원소 개수 얻기
int getCnt(int node, int start, int end, int left, int right, int k) {
if (left > end || right < start) return 0; //탐색 범위가 얻을 범위에 포함되지 않는 경우
if (left <= start && end <= right) { //탐색 범위가 얻을 범위에 완전히 포함되는 경우
return segTree[node].end() - upper_bound(segTree[node].begin(), segTree[node].end(), k);
}
//탐색 범위가 얻을 범위에 걸치는 경우
int mid = (start + end) / 2;
int leftRst = getCnt(node * 2, start, mid, left, right, k);
int rightRst = getCnt(node * 2 + 1, mid + 1, end, left, right, k);
return leftRst + rightRst;
}