C++

重み付きUnionFindライブラリを整備する

ISO_C++_Logo

UnionFind木を重み付きに拡張しました。

復習)UnionFind

このブログで紹介したUnionFindライブラリは、以下です。

メソッド名は、AtCoder が公開しているライブラリ AtCoder Library と合わせています。

#include <algorithm>    // swap
#include <vector>

class UnionFind {
private:
	int n;
	std::vector<int> par;

public:
	UnionFind(int _n) {
		n = _n;
		par.resize(n, -1);
	}

	int leader(int x) {
		if (par[x] < 0) {
			return x;
		} else {
			return par[x] = leader(par[x]);
		}
	}

	void merge(int x, int y) {
		x = leader(x);
		y = leader(y);
		if (x != y) {
			if (-par[x] < -par[y]) {
				std::swap(x, y);
			}
			par[x] += par[y];
			par[y] = x;
		}
	}

	bool same(int x, int y) {
		return leader(x) == leader(y);
	}

	int size(int x) {
		return -par[leader(x)];
	}

	std::vector<std::vector<int>> groups() {
		std::vector<std::vector<int>> member(n);
		for (int i = 0; i < n; ++i) {
			member[leader(i)].push_back(i);
		}

		std::vector<std::vector<int>> result;
		for (int i = 0; i < n; ++i) {
			if (!member[i].empty()) {
				result.push_back(member[i]);
			}
		}

		return result;
	}
};

重み付きUnionFind

上記のUnionFindライブラリを拡張して、頂点の間の重さを管理できるようにします。以下の実装は、けんちょんさんのQiitaの記事を参考にしました。

それぞれの連結成分の根の重さを0とします。根と異なる頂点の重さは、連結するmergeメソッドで指定した値を用いて更新します。重さの型は、long long int (ll) としました。

変更または追加するメソッドは以下となります。

void merge(x, y, z)頂点xと頂点yを連結する。
weight(y) = weight(x) + w となるように重さを付ける。
ll weight(x)頂点xの重さを返す。
ll diff(x, y)weight(y) – weight(x) を返す。

実装についてのコメントは以下です。

  • Privateなメンバとして、配列 diff_weight を追加します。これはひとつ上の親との重さの差を格納します。
  • 連結成分の根を求める leader メソッド呼び出し時に経路圧縮を行います。経路圧縮に合わせて、再帰的に diff_weight を更新します。結果的に leader 呼び出し後は、根と頂点との重さの差が diff_weight に格納されます。
  • weight メソッドは、leader を呼び出した後に diff_weight を返します。
  • merge メソッドでは、leader(x) と leader(y) を連結するため、指定された重さの差 w を調整します。サイズが大きい方が親になるように連結します。このときに必要に応じて、wを-wに置き換えます。

ライブラリの実装

UnionFindからのポイントとなる差分の背景色を変更しました。

#include <algorithm>    // swap
#include <vector>

typedef long long int ll;

class Weighted_UnionFind {
private:
	int n;
	std::vector<int> par;
	std::vector<ll> diff_weight;

public:
	Weighted_UnionFind(int _n) {
		n = _n;
		par.resize(n, -1);
		diff_weight.resize(n, 0);
	}

	int leader(int x) {
		if (par[x] < 0) {
			return x;
		} else {
			int r = leader(par[x]);
			diff_weight[x] += diff_weight[par[x]];
			return par[x] = r;
		}
	}

	ll weight(int x) {
		leader(x);
		return diff_weight[x];
	}

	ll diff(int x, int y) {
		return weight(y) - weight(x);
	}

	void merge(int x, int y, ll w) {
		w += weight(x);
		w -= weight(y);
		x = leader(x);
		y = leader(y);
		if (x != y) {
			if (-par[x] < -par[y]) {
				std::swap(x, y);
				w = -w;
			}
			par[x] += par[y];
			par[y] = x;
			diff_weight[y] = w;
		}
	}

	bool same(int x, int y) {
		return leader(x) == leader(y);
	}

	int size(int x) {
		return -par[leader(x)];
	}

	std::vector<std::vector<int>> groups() {
		std::vector<std::vector<int>> member(n);
		for (int i = 0; i < n; ++i) {
			member[leader(i)].push_back(i);
		}

		std::vector<std::vector<int>> result;
		for (int i = 0; i < n; ++i) {
			if (!member[i].empty()) {
				result.push_back(member[i]);
			}
		}

		return result;
	}
};

ライブラリのテスト

AIZU ONLINE JUDGE で出題されている問題でライブラリをテストします。以下が問題のリンク先です。

DSL 1_B問題 Weighted Union Find Trees

この問題を解くプログラムは以下です。AC判定となります。

#include <iostream>
#include <algorithm>    // swap
#include <vector>

using namespace std;

typedef long long int ll;

class Weighted_UnionFind {
private:
	int n;
	std::vector<int> par;
	std::vector<ll> diff_weight;

public:
	Weighted_UnionFind(int _n) {
		n = _n;
		par.resize(n, -1);
		diff_weight.resize(n, 0);
	}

	int leader(int x) {
		if (par[x] < 0) {
			return x;
		} else {
			int r = leader(par[x]);
			diff_weight[x] += diff_weight[par[x]];
			return par[x] = r;
		}
	}

	ll weight(int x) {
		leader(x);
		return diff_weight[x];
	}

	ll diff(int x, int y) {
		return weight(y) - weight(x);
	}

	void merge(int x, int y, ll w) {
		w += weight(x);
		w -= weight(y);
		x = leader(x);
		y = leader(y);
		if (x != y) {
			if (-par[x] < -par[y]) {
				std::swap(x, y);
				w = -w;
			}
			par[x] += par[y];
			par[y] = x;
			diff_weight[y] = w;
		}
	}

	bool same(int x, int y) {
		return leader(x) == leader(y);
	}

	int size(int x) {
		return -par[leader(x)];
	}

	std::vector<std::vector<int>> groups() {
		std::vector<std::vector<int>> member(n);
		for (int i = 0; i < n; ++i) {
			member[leader(i)].push_back(i);
		}

		std::vector<std::vector<int>> result;
		for (int i = 0; i < n; ++i) {
			if (!member[i].empty()) {
				result.push_back(member[i]);
			}
		}

		return result;
	}
};

int main()
{
	int n, q;
	cin >> n >> q;

	Weighted_UnionFind uf(n);
	for (int i = 0; i < q; ++i) {
		int type;
		int x, y;
		cin >> type >> x >> y;
		if (type == 0) {
			ll z;
			cin >> z;
			uf.merge(x, y, z);
		} else if (type == 1) {
			if (uf.same(x, y)) {
				cout << uf.diff(x, y) << endl;
			} else {
				cout << "?" << endl;
			}
		}
	}

	return 0;
}

最後に

ABC328 F問題のために、重み付きUnionFindを整備しました。少しずつ道具も増やしていきたいです。

COMMENT

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

CAPTCHA