Home [백준 14621번] 나만 안되는 연애 [C++]
Post
Cancel

[백준 14621번] 나만 안되는 연애 [C++]

14621번: 나만 안되는 연애

문제 링크: 14621번: 나만 안되는 연애


개요

주어진 그래프에서 노드들을 연결하는 최단 거리 트리의 총 거리를 구하는 문제이다. 최소 스패닝 트리 알고리즘 중 크루스칼 알고리즘을 이용해 풀이하겠다.
최소 스패닝 트리에 대해 [알고리즘] 최소 신장 트리(MST, Minimum Spanning Tree)란 포스팅을 참고하면 좋겠다.


풀이

getRoot 함수

1
2
3
4
5
6
int getRoot(int num) {
	if (num != prnt[num]) {
		prnt[num] = getRoot(prnt[num]);
	}
	return prnt[num];
}

학교 a와 b가 있을 때, prnt[a]가 b인 경우엔 a와 b가 직접 연결되어 있는 경우이고, getRoot(a)와 getRoot(b)가 같은 경우엔 두 대학은 연결된 경로가 있는 경우이다. 두 경우는 동시에 발생할 수 있다.
매번 부모(직접 연결된 노드) 타고 올라가며 루트 노드를 확인하여 비교하면 시간이 오래 걸리기 때문에 어떤 노드의 루트 노드를 얻음과 동시에 해당 노드부터 루트 노드까지 존재하는 모든 노드(n)의 prnt[n]을 루트 노드로 갱신시켜주어 이후 탐색 시에 시간을 단축시키는 역할이다.
루트 노드(부모가 자기 자신인 노드)에 도달할 때까지 재귀를 반복하고, 자식 노드를 루트 노드로 갱신해준 후 반환한다.

경로 길이 계산 이전

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
    int n, m, cnt = 0, sumWeight = 0;
    scanf("%d %d", &n, &m); getchar();

    for (int i = 1; i <= n; i++) {
        prnt[i] = i;

        char c;
        scanf("%c", &c); getchar();
        if (c == 'W') gender[i] = 1;
    }

    for (int i = 0; i < m; i++) {
        int u, v, w;
        scanf("%d %d %d", &u, &v, &w);

        tree[i].node1 = u, tree[i].node2 = v;
        tree[i].weight = w;
    }

    qsort(tree, m, sizeof(Edge), compare);

대학의 인덱스는 1부터 n까지의 정수이며, 대학 하나를 노드 하나로 취급할 것이다.
초기 상태에서 모든 대학은 연결되어 있기 않기 때문에 모든 노드의 부모를 자기 자신의 번호로 초기화 시켜준다.
또한 남초 대학교와 여초 대학교가 구분되어 있기 때문에 gender 배열의 해당 대학 인덱스에 남초 대학이면 0, 여초 대학이면 1이 들어가있도록 해준다.
이후 구조체로 만들어둔 간선 배열에 간선을 입력받고, 도로의 거리에 대해 오름차순으로 정렬한다.

경로 길이 계산

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
    for (int i = 0; i < m && cnt < n; i++) {
		int node1 = tree[i].node1;
		int node2 = tree[i].node2;
		int rootNode1 = getRoot(node1);
		int rootNode2 = getRoot(node2);
		if (rootNode1 != rootNode2 && gender[node1] != gender[node2]) {
			sumWeight += tree[i].weight;
			prnt[node1] = rootNode1;
			prnt[rootNode2] = rootNode1;
			cnt++;
		}
	}

	if (cnt < n - 1) printf("-1");
	else printf("%d", sumWeight);

정렬된 간선을 순서대로 탐색하며 해당 간선이 연결하는 두 대학이 둘 다 남초거나 둘 다 여초인 경우(gender 배열 값이 같은 경우), 그리고 이미 연결된 대학(조상 노드가 같은 경우)인 경우는 건너뛴다.
그렇지 않은 경우는 두 대학을 같은 그룹으로 만들고, 누적 거리에 해당 간선 거리를 더한다.
n개 대학을 전부 연결하기 위해서는 최소 (n-1)개의 간선이 필요하다. 간선 배열을 전부 탐색한 이후 시점에서 연결된 간선의 수(cnt)가 n-1보다 작은 경우, 연결되지 않은 대학이 존재한다는 뜻이므로 -1을 출력하고 그렇지 않은 경우엔 누적 거리(sumWeight)를 출력하면 된다.


전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <iostream> 
#include <vector> 
#pragma warning(disable:4996)
#pragma warning(disable:6031)
using namespace std;

#define MAX_N 1000
#define MAX_M 10000

typedef struct {
	int node1, node2, weight;
} Edge;

Edge tree[MAX_M];
int prnt[MAX_N + 1];
int gender[MAX_N + 1] = { 0 };

int compare(const void *a, const void *b);
int getRoot(int num);

int main() {

	int n, m, cnt = 0, sumWeight = 0;
	scanf("%d %d", &n, &m); getchar();

	for (int i = 1; i <= n; i++) {
		prnt[i] = i;

		char c;
		scanf("%c", &c); getchar();
		if (c == 'W') gender[i] = 1;
	}

	for (int i = 0; i < m; i++) {
		int u, v, w;
		scanf("%d %d %d", &u, &v, &w);

		tree[i].node1 = u, tree[i].node2 = v;
		tree[i].weight = w;
	}

	qsort(tree, m, sizeof(Edge), compare);

	for (int i = 0; i < m && cnt < n; i++) {
		int node1 = tree[i].node1;
		int node2 = tree[i].node2;
		int rootNode1 = getRoot(node1);
		int rootNode2 = getRoot(node2);
		if (rootNode1 != rootNode2 && gender[node1] != gender[node2]) {
			sumWeight += tree[i].weight;
			prnt[node1] = rootNode1;
			prnt[rootNode2] = rootNode1;
			cnt++;
		}
	}

	if (cnt < n - 1) printf("-1");
	else printf("%d", sumWeight);

	return 0;
}

int compare(const void *a, const void *b) {
	int x = (*(Edge *)a).weight;
	int y = (*(Edge *)b).weight;
	if (x > y) return 1;
	else if (x < y) return -1;
	return 0;
}

int getRoot(int num) {
	if (num != prnt[num]) {
		prnt[num] = getRoot(prnt[num]);
	}
	return prnt[num];
}

This post is licensed under CC BY 4.0 by the author.

[백준 17387번] 선분 교차 2 [C++]

[백준 2696번] 중앙값 구하기 [C++]