UnionFind木を重み付きに拡張しました。
復習)UnionFind
このブログで紹介したUnionFindライブラリは、以下です。
メソッド名は、AtCoder が公開しているライブラリ AtCoder Library と合わせています。
#include <algorithm> // swap
#include <vector>
class UnionFind {
private:
int n;
std::vector<int> par;
public:
UnionFind(int _n) {
n = _n;
par.resize(n, -1);
}
int leader(int x) {
if (par[x] < 0) {
return x;
} else {
return par[x] = leader(par[x]);
}
}
void merge(int x, int y) {
x = leader(x);
y = leader(y);
if (x != y) {
if (-par[x] < -par[y]) {
std::swap(x, y);
}
par[x] += par[y];
par[y] = x;
}
}
bool same(int x, int y) {
return leader(x) == leader(y);
}
int size(int x) {
return -par[leader(x)];
}
std::vector<std::vector<int>> groups() {
std::vector<std::vector<int>> member(n);
for (int i = 0; i < n; ++i) {
member[leader(i)].push_back(i);
}
std::vector<std::vector<int>> result;
for (int i = 0; i < n; ++i) {
if (!member[i].empty()) {
result.push_back(member[i]);
}
}
return result;
}
};
重み付きUnionFind
上記のUnionFindライブラリを拡張して、頂点の間の重さを管理できるようにします。以下の実装は、けんちょんさんのQiitaの記事を参考にしました。
それぞれの連結成分の根の重さを0とします。根と異なる頂点の重さは、連結するmergeメソッドで指定した値を用いて更新します。重さの型は、long long int (ll) としました。
変更または追加するメソッドは以下となります。
void merge(x, y, z) | 頂点xと頂点yを連結する。 weight(y) = weight(x) + w となるように重さを付ける。 |
ll weight(x) | 頂点xの重さを返す。 |
ll diff(x, y) | weight(y) – weight(x) を返す。 |
実装についてのコメントは以下です。
- Privateなメンバとして、配列 diff_weight を追加します。これはひとつ上の親との重さの差を格納します。
- 連結成分の根を求める leader メソッド呼び出し時に経路圧縮を行います。経路圧縮に合わせて、再帰的に diff_weight を更新します。結果的に leader 呼び出し後は、根と頂点との重さの差が diff_weight に格納されます。
- weight メソッドは、leader を呼び出した後に diff_weight を返します。
- merge メソッドでは、leader(x) と leader(y) を連結するため、指定された重さの差 w を調整します。サイズが大きい方が親になるように連結します。このときに必要に応じて、wを-wに置き換えます。
ライブラリの実装
UnionFindからのポイントとなる差分の背景色を変更しました。
#include <algorithm> // swap
#include <vector>
typedef long long int ll;
class Weighted_UnionFind {
private:
int n;
std::vector<int> par;
std::vector<ll> diff_weight;
public:
Weighted_UnionFind(int _n) {
n = _n;
par.resize(n, -1);
diff_weight.resize(n, 0);
}
int leader(int x) {
if (par[x] < 0) {
return x;
} else {
int r = leader(par[x]);
diff_weight[x] += diff_weight[par[x]];
return par[x] = r;
}
}
ll weight(int x) {
leader(x);
return diff_weight[x];
}
ll diff(int x, int y) {
return weight(y) - weight(x);
}
void merge(int x, int y, ll w) {
w += weight(x);
w -= weight(y);
x = leader(x);
y = leader(y);
if (x != y) {
if (-par[x] < -par[y]) {
std::swap(x, y);
w = -w;
}
par[x] += par[y];
par[y] = x;
diff_weight[y] = w;
}
}
bool same(int x, int y) {
return leader(x) == leader(y);
}
int size(int x) {
return -par[leader(x)];
}
std::vector<std::vector<int>> groups() {
std::vector<std::vector<int>> member(n);
for (int i = 0; i < n; ++i) {
member[leader(i)].push_back(i);
}
std::vector<std::vector<int>> result;
for (int i = 0; i < n; ++i) {
if (!member[i].empty()) {
result.push_back(member[i]);
}
}
return result;
}
};
ライブラリのテスト
AIZU ONLINE JUDGE で出題されている問題でライブラリをテストします。以下が問題のリンク先です。
DSL 1_B問題 Weighted Union Find Trees
この問題を解くプログラムは以下です。AC判定となります。
#include <iostream>
#include <algorithm> // swap
#include <vector>
using namespace std;
typedef long long int ll;
class Weighted_UnionFind {
private:
int n;
std::vector<int> par;
std::vector<ll> diff_weight;
public:
Weighted_UnionFind(int _n) {
n = _n;
par.resize(n, -1);
diff_weight.resize(n, 0);
}
int leader(int x) {
if (par[x] < 0) {
return x;
} else {
int r = leader(par[x]);
diff_weight[x] += diff_weight[par[x]];
return par[x] = r;
}
}
ll weight(int x) {
leader(x);
return diff_weight[x];
}
ll diff(int x, int y) {
return weight(y) - weight(x);
}
void merge(int x, int y, ll w) {
w += weight(x);
w -= weight(y);
x = leader(x);
y = leader(y);
if (x != y) {
if (-par[x] < -par[y]) {
std::swap(x, y);
w = -w;
}
par[x] += par[y];
par[y] = x;
diff_weight[y] = w;
}
}
bool same(int x, int y) {
return leader(x) == leader(y);
}
int size(int x) {
return -par[leader(x)];
}
std::vector<std::vector<int>> groups() {
std::vector<std::vector<int>> member(n);
for (int i = 0; i < n; ++i) {
member[leader(i)].push_back(i);
}
std::vector<std::vector<int>> result;
for (int i = 0; i < n; ++i) {
if (!member[i].empty()) {
result.push_back(member[i]);
}
}
return result;
}
};
int main()
{
int n, q;
cin >> n >> q;
Weighted_UnionFind uf(n);
for (int i = 0; i < q; ++i) {
int type;
int x, y;
cin >> type >> x >> y;
if (type == 0) {
ll z;
cin >> z;
uf.merge(x, y, z);
} else if (type == 1) {
if (uf.same(x, y)) {
cout << uf.diff(x, y) << endl;
} else {
cout << "?" << endl;
}
}
}
return 0;
}
最後に
ABC328 F問題のために、重み付きUnionFindを整備しました。少しずつ道具も増やしていきたいです。