C++

UnionFindライブラリを整備する

ISO_C++_Logo

競技プログラミングでよく使われるUnionFind木を扱うライブラリを整備しました。

更新前ライブラリ

ABC304E問題解説)で紹介した UnionFind ライブラリは以下です。

これは、「蟻本」(「プログラミングコンテストチャレンジブック」第2版 秋葉拓哉、岩田陽一、北川宣稔著、マイナビ 2012年)を参考に実装しました。ただし、AtCoder が公開しているライブラリ AtCoder Library と合わせるため、メソッド名を変更しています。

#include <vector>

class UnionFind {
private:
	vector<int> par;
	vector<int> rank;

public:
	UnionFind(int n) {
		par.resize(n);
		rank.resize(n, 0);
		for (int i = 0; i < n; i++) {
			par[i] = i;
		}
	}

	int leader(int x) {
		if (par[x] == x) {
			return x;
		} else {
			return par[x] = leader(par[x]);
		}
	}

	void merge(int x, int y) {
		x = leader(x);
		y = leader(y);
		if (x == y) {
			return;
		}

		if (rank[x] < rank[y]) {
			par[x] = y;
		} else {
			par[y] = x;
			if (rank[x] == rank[y]) {
				++rank[x];
			}
		}
	}

	bool same(int x, int y) {
		return leader(x) == leader(y);
	}
};

このライブラリは、問題なく動作していました。ただし、AtCoder Library と比較すると以下の機能がありません。

  • 頂点 x の属する連結成分のサイズを返す size(x)
  • グラフの連結成分を返す groups()

これらの機能を加えることで自前ライブラリを更新しました。

更新後ライブラリ

機能を追加するために既存コードに対して、以下の変更をしました。

  • 元のライブラリは、木の高さ(rank)が高い木に低い木を merge しました。これをサイズが大きい木にサイズが小さい木を merge するように変更しました。元のライブラリで使っていた配列 rank は削除します。
  • 内部配列 par を以下の役割に変更する。
    • 自分が木の根の場合は、その木のサイズにマイナス1を掛けた値を保持する。
    • 自分が木の根の場合は、属する木の根を保持する。
  • 宣言 using namespace std; が無くても動作するように、必要に応じて std:: を追加した。

関数 size は、属する木の根の par の値にマイナス1を掛けた値を返すだけです。このアイデアは、AtCoder Library の実装を参考にしました。同じ変数に対して、値の範囲で役割を変えることは一般的に良くないと考えられています。ライブラリ内部の閉じた変数なので許容することにします。

関数 groups は、木の根を添え字に持つ Vector コンテナ member にすべての要素を格納した後に、実際にメンバーが存在する Vector コンテナのみを result に格納して、これを返します。この実装は、けんちょんさんのブログ記事を参考にしました。

以下が、更新後の UnionFind ライブラリとなります。

#include <algorithm>    // swap
#include <vector>

class UnionFind {
private:
	int n;
	std::vector<int> par;

public:
	UnionFind(int _n) {
		n = _n;
		par.resize(n, -1);
	}

	int leader(int x) {
		if (par[x] < 0) {
			return x;
		} else {
			return par[x] = leader(par[x]);
		}
	}

	void merge(int x, int y) {
		x = leader(x);
		y = leader(y);
		if (x != y) {
			if (-par[x] < -par[y]) {
				std::swap(x, y);
			}
			par[x] += par[y];
			par[y] = x;
		}
	}

	bool same(int x, int y) {
		return leader(x) == leader(y);
	}

	int size(int x) {
		return -par[leader(x)];
	}

	std::vector<std::vector<int>> groups() {
		std::vector<std::vector<int>> member(n);
		for (int i = 0; i < n; ++i) {
			member[leader(i)].push_back(i);
		}

		std::vector<std::vector<int>> result;
		for (int i = 0; i < n; ++i) {
			if (!member[i].empty()) {
				result.push_back(member[i]);
			}
		}

		return result;
	}
};

ライブラリのテスト

AtCoder Library Practice Contest でライブラリをテストします。以下が問題のリンク先です。

A問題 Disjoint Set Union

この問題を解くプログラムは以下です。AC判定となります。このプログラムは、AOJ のコース DSL 1_A問題でもACとなります。

#include <iostream>
#include <algorithm>    // swap
#include <vector>

using namespace std;

class UnionFind {
private:
	int n;
	std::vector<int> par;

public:
	UnionFind(int _n) {
		n = _n;
		par.resize(n, -1);
	}

	int leader(int x) {
		if (par[x] < 0) {
			return x;
		} else {
			return par[x] = leader(par[x]);
		}
	}

	void merge(int x, int y) {
		x = leader(x);
		y = leader(y);
		if (x != y) {
			if (-par[x] < -par[y]) {
				std::swap(x, y);
			}
			par[x] += par[y];
			par[y] = x;
		}
	}

	bool same(int x, int y) {
		return leader(x) == leader(y);
	}

	int size(int x) {
		return -par[leader(x)];
	}

	std::vector<std::vector<int>> groups() {
		std::vector<std::vector<int>> member(n);
		for (int i = 0; i < n; ++i) {
			member[leader(i)].push_back(i);
		}

		std::vector<std::vector<int>> result;
		for (int i = 0; i < n; ++i) {
			if (!member[i].empty()) {
				result.push_back(member[i]);
			}
		}

		return result;
	}
};

int main()
{
	int n, q;
	cin >> n >> q;

	UnionFind uf(n);
	for (int i = 0; i < q; ++i) {
		int type;
		int x, y;
		cin >> type >> x >> y;
		if (type == 0) {
			uf.merge(x, y);
		} else if (type == 1) {
			if (uf.same(x, y)) {
				cout << 1 << endl;
			} else {
				cout << 0 << endl;
			}
		}
	}

	return 0;
}

もちろん次のように ACL を使えば、プログラムを簡潔に書くことできます。関係するコードの背景色を変更しました。

#include <iostream>
#include <atcoder/dsu>

using namespace std;
using namespace atcoder;

int main()
{
	int n, q;
	cin >> n >> q;

	dsu uf(n);
	for (int i = 0; i < q; ++i) {
		int type;
		int x, y;
		cin >> type >> x >> y;
		if (type == 0) {
			uf.merge(x, y);
		} else if (type == 1) {
			if (uf.same(x, y)) {
				cout << 1 << endl;
			} else {
				cout << 0 << endl;
			}
		}
	}

	return 0;
}

最後に

UnionFindは、競プロ向けに用意した初めてのライブラリでした。関係する問題を解いて、手になじむ道具にしていきたいです。

COMMENT

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

CAPTCHA